1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17#include "tensorflow/core/kernels/training_ops.h"
18
19#include <algorithm> // NOLINT
20
21#include "tensorflow/core/framework/bounds_check.h"
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/framework/register_types.h"
24#include "tensorflow/core/kernels/training_op_helpers.h"
25#include "tensorflow/core/kernels/variable_ops.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/platform/bfloat16.h"
28#include "tensorflow/core/util/util.h"
29
30namespace tensorflow {
31
32using CPUDevice = Eigen::ThreadPoolDevice;
33using GPUDevice = Eigen::GpuDevice;
34using Index = Eigen::Index;
35
36namespace {
37template <class T>
38inline T sgn(const T x) {
39 T zero(0);
40 T one(1);
41 return (x == zero ? zero : (x < zero ? -one : one));
42}
43} // namespace
44
45namespace functor {
46template <typename T>
47struct ApplyGradientDescent<CPUDevice, T> {
48 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
49 typename TTypes<T>::ConstScalar lr,
50 typename TTypes<T>::ConstFlat grad) {
51 var.device(d) -= grad * lr();
52 }
53};
54
55template <typename T>
56struct ApplyAdadelta<CPUDevice, T> {
57 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
58 typename TTypes<T>::Flat accum,
59 typename TTypes<T>::Flat accum_update,
60 typename TTypes<T>::ConstScalar lr,
61 typename TTypes<T>::ConstScalar rho,
62 typename TTypes<T>::ConstScalar epsilon,
63 typename TTypes<T>::ConstFlat grad) {
64 accum.device(d) =
65 accum * rho() + grad.square() * (static_cast<T>(1) - rho());
66 const auto update =
67 (accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
68 var.device(d) -= update * lr();
69 accum_update.device(d) =
70 accum_update * rho() + update.square() * (static_cast<T>(1) - rho());
71 }
72};
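// For reference, the Adadelta step implemented above is (a restatement of the
// code, not additional behavior):
//   accum        <- rho * accum + (1 - rho) * grad^2
//   update       =  sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
//   var          <- var - lr * update
//   accum_update <- rho * accum_update + (1 - rho) * update^2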
73
74template <typename T>
75struct ApplyProximalGradientDescent<CPUDevice, T> {
76 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
77 typename TTypes<T>::ConstScalar lr,
78 typename TTypes<T>::ConstScalar l1,
79 typename TTypes<T>::ConstScalar l2,
80 typename TTypes<T>::ConstFlat grad) {
81 // Note that this is the FOBOS update; for details, please refer to:
82 // http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf
83 // TODO(xbing): merge the logic for ProximalGradientDescent and
84 // ProximalAdagrad.
85 auto prox_var = var;
86 // compute v = w - lr * grad.
87 prox_var.device(d) -= grad * lr();
88 if (l1() > 0) {
89 // compute sign(v) * max(|v| - lr * l1, 0)
90 var.device(d) =
91 prox_var.sign() *
92 (prox_var.abs() - var.constant(lr() * l1())).cwiseMax(T(0.0)) /
93 (var.constant(1.0) + var.constant(l2() * lr()));
94 } else {
95 var.device(d) =
96 prox_var / (var.constant(1.0) + var.constant(l2() * lr()));
97 }
98 }
99};
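// For reference, the FOBOS/proximal step implemented above is (a restatement
// of the code, not additional behavior):
//   v   = var - lr * grad
//   var <- sign(v) * max(|v| - lr * l1, 0) / (1 + lr * l2)   if l1 > 0
//   var <- v / (1 + lr * l2)                                 otherwise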
100
101template <typename T>
102struct ApplyAdagradDA<CPUDevice, T> {
103 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
104 typename TTypes<T>::Flat gradient_accum,
105 typename TTypes<T>::Flat gradient_squared_accum,
106 typename TTypes<T>::ConstScalar lr, int64_t global_step,
107 typename TTypes<T>::ConstScalar l1,
108 typename TTypes<T>::ConstScalar l2,
109 typename TTypes<T>::ConstFlat grad) {
110 // Accumulate the gradient and the squared gradient.
111 gradient_accum.device(d) += grad;
112 gradient_squared_accum.device(d) += grad.square();
113
114 // AdagradDA update:
115 // Let g be the gradient accumulator, gg the gradient squared accumulator,
116 // T the global step, lr the learning rate, and k the initial value of the
117 // gradient squared accumulator.
118 // w = \dfrac{sign(-g) \cdot lr \cdot (|g| - l1 \cdot T)_{+}}{l2 \cdot T \cdot lr + \sqrt{k + gg}}
119 if (l1() > 0) {
120 var.device(d) =
121 lr() * var.constant(-1.0) * gradient_accum.sign() *
122 (gradient_accum.abs() -
123 var.constant(static_cast<float>(global_step)) * var.constant(l1()))
124 .cwiseMax(T(0.0)) /
125 (var.constant(l2()) *
126 var.constant(static_cast<float>(global_step) * lr()) +
127 gradient_squared_accum.sqrt());
128 } else {
129 var.device(d) =
130 lr() * gradient_accum * var.constant(-1.0) /
131 (var.constant(l2()) *
132 var.constant(static_cast<float>(global_step) * lr()) +
133 gradient_squared_accum.sqrt());
134 }
135 }
136};
137
138template <typename T>
139struct ApplyAdagrad<CPUDevice, T> {
140 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
141 typename TTypes<T>::Flat accum,
142 typename TTypes<T>::ConstScalar lr,
143 typename TTypes<T>::ConstFlat grad, bool update_slots) {
144 if (update_slots) {
145 accum.device(d) += grad.square();
146 }
147 var.device(d) -= grad * lr() * accum.rsqrt();
148 }
149};
150
151template <typename T>
152struct ApplyAdagradV2<CPUDevice, T> {
153 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
154 typename TTypes<T>::Flat accum,
155 typename TTypes<T>::ConstScalar lr,
156 typename TTypes<T>::ConstScalar epsilon,
157 typename TTypes<T>::ConstFlat grad, bool update_slots) {
158 if (update_slots) {
159 accum.device(d) += grad.square();
160 }
161 var.device(d) -= grad * lr() / (accum.sqrt() + epsilon());
162 }
163};
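// For reference, the two Adagrad variants above compute (restating the code):
//   accum <- accum + grad^2                               (when update_slots)
//   var   <- var - lr * grad / sqrt(accum)                (ApplyAdagrad)
//   var   <- var - lr * grad / (sqrt(accum) + epsilon)    (ApplyAdagradV2)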
164
165template <typename T, typename Tindex, bool has_epsilon>
166struct SparseApplyAdagrad<CPUDevice, T, Tindex, has_epsilon> {
167 Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var,
168 typename TTypes<T>::Matrix accum,
169 typename TTypes<T>::ConstScalar lr,
170 typename TTypes<T>::ConstScalar epsilon,
171 typename TTypes<T>::ConstMatrix grad,
172 typename TTypes<Tindex>::ConstVec indices,
173 int64_t inner_dim, bool update_slots) {
174 const Tindex N = static_cast<Tindex>(indices.dimension(0));
175 if (N == 0) return OkStatus();
176 const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
177 const T lr_scalar = lr();
178 const int in_bytes = inner_dim * sizeof(T) * 3;
179 const int out_bytes = inner_dim * sizeof(T) * 2;
180 const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 +
181 Eigen::TensorOpCost::MulCost<T>() * 2);
182 const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles);
183
184 if (inner_dim > 1) {
185 for (Tindex i = 0; i < N; ++i) {
186 const Tindex index = internal::SubtleMustCopy(indices(i));
187 if (!FastBoundsCheck(index, first_dim_size)) {
188 return errors::InvalidArgument(
189 strings::StrCat("Index ", index, " at offset ", i,
190 " in indices is out of range"));
191 }
192 }
193
194 const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
195 for (Tindex i = start_idx; i < end_idx; ++i) {
196 const Tindex index = internal::SubtleMustCopy(indices(i));
197 auto a = accum.template chip<0>(index);
198 auto g = grad.template chip<0>(i);
199 auto v = var.template chip<0>(index);
200 if (update_slots) {
201 a += g.square();
202 }
203 if (has_epsilon) {
204 v -= g.constant(lr_scalar) * g / (a.sqrt() + a.constant(epsilon()));
205 } else {
206 v -= g.constant(lr_scalar) * g * a.rsqrt();
207 }
208 }
209 };
210
211 d.parallelFor(N, cost, shard);
212 } else {
213 for (Tindex i = 0; i < N; ++i) {
214 const Tindex index = internal::SubtleMustCopy(indices(i));
215 if (!FastBoundsCheck(index, first_dim_size)) {
216 return errors::InvalidArgument(
217 strings::StrCat("Index ", index, " at offset ", i,
218 " in indices is out of range"));
219 }
220 }
221
222 const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void {
223 for (Tindex i = start_idx; i < end_idx; ++i) {
224 const Tindex index = internal::SubtleMustCopy(indices(i));
225 T& a = accum(index);
226 const T& g = grad(i);
227 if (update_slots) {
228 a += g * g;
229 }
230 if (has_epsilon) {
231 var(index) -= lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon());
232 } else {
233 var(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
234 }
235 }
236 };
237
238 d.parallelFor(N, cost, shard);
239 }
240
241 return OkStatus();
242 }
243};
244
245template <typename T>
246struct ApplyProximalAdagrad<CPUDevice, T> {
247 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
248 typename TTypes<T>::Flat accum,
249 typename TTypes<T>::ConstScalar lr,
250 typename TTypes<T>::ConstScalar l1,
251 typename TTypes<T>::ConstScalar l2,
252 typename TTypes<T>::ConstFlat grad) {
253 // FOBOS update (see the paper referenced above) with the Adagrad learning rate.
254 accum.device(d) += grad.square();
255 // Adagrad learning rate.
256 auto learning_rate = accum.constant(lr()) * accum.rsqrt();
257 auto prox_var = var;
258 // compute v = w - lr * grad.
259 prox_var.device(d) -= grad * learning_rate;
260 if (l1() > 0) {
261 // compute sign(v) * max(|v| - lr * l1, 0)
262 var.device(d) = prox_var.sign() *
263 (prox_var.abs() - learning_rate * prox_var.constant(l1()))
264 .cwiseMax(T(0.0)) /
265 (var.constant(1.0) + var.constant(l2()) * learning_rate);
266 } else {
267 var.device(d) =
268 prox_var / (var.constant(1.0) + var.constant(l2()) * learning_rate);
269 }
270 }
271};
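// For reference, the proximal Adagrad step implemented above is (restating
// the code; the sparse functor below applies the same formula per row):
//   accum <- accum + grad^2
//   lr_t  =  lr / sqrt(accum)
//   v     =  var - lr_t * grad
//   var   <- sign(v) * max(|v| - lr_t * l1, 0) / (1 + lr_t * l2)   if l1 > 0
//   var   <- v / (1 + lr_t * l2)                                   otherwise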
272
273template <typename T, typename Tindex>
274struct SparseApplyProximalAdagrad<CPUDevice, T, Tindex> {
275 Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var,
276 typename TTypes<T>::Matrix accum,
277 typename TTypes<T>::ConstScalar lr,
278 typename TTypes<T>::ConstScalar l1,
279 typename TTypes<T>::ConstScalar l2,
280 typename TTypes<T>::ConstMatrix grad,
281 typename TTypes<Tindex>::ConstVec indices,
282 int64_t inner_dim) {
283 const Tindex N = static_cast<Tindex>(indices.dimension(0));
284 if (N == 0) return OkStatus();
285 const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
286 const T lr_scalar = lr();
287 const T l1_scalar = l1();
288 const T l2_scalar = l2();
289 if (inner_dim > 1) {
290 for (Tindex i = 0; i < N; i++) {
291 const Tindex index = internal::SubtleMustCopy(indices(i));
292 if (!FastBoundsCheck(index, first_dim_size)) {
293 return errors::InvalidArgument(
294 strings::StrCat("Index ", index, " at offset ", i,
295 " in indices is out of range"));
296 }
297 auto a = accum.template chip<0>(index);
298 auto g = grad.template chip<0>(i);
299 auto v = var.template chip<0>(index);
300 a += g.square();
301 // compute learning_rate for current step.
302 auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
303 auto prox_v = v;
304 // v = w - g * learning_rate.
305 prox_v -= g * learning_rate;
306 if (l1_scalar > 0) {
307 // compute sign(v) * max(|v| - learning_rate * l1, 0)
308 v = prox_v.sign() *
309 (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
310 .cwiseMax(static_cast<T>(0.0)) /
311 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
312 } else {
313 v = prox_v /
314 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
315 }
316 }
317 } else {
318 for (Tindex i = 0; i < N; i++) {
319 const Tindex index = internal::SubtleMustCopy(indices(i));
320 if (!FastBoundsCheck(index, first_dim_size)) {
321 return errors::InvalidArgument(
322 strings::StrCat("Index ", index, " at offset ", i,
323 " in indices is out of range"));
324 }
325 T& a = accum(index);
326 const T& g = grad(i);
327 a += g * g;
328 auto learning_rate = lr_scalar / std::sqrt(a);
329 auto prox_v = var(index);
330 prox_v -= learning_rate * g;
331 if (l1_scalar > 0) {
332 var(index) = sgn(prox_v) *
333 std::max(std::abs(prox_v) - learning_rate * l1_scalar,
334 static_cast<T>(0.0)) /
335 (1.0 + l2_scalar * learning_rate);
336 } else {
337 var(index) = prox_v / (1.0 + l2_scalar * learning_rate);
338 }
339 }
340 }
341 return OkStatus();
342 }
343};
344
345template <typename T>
346struct ApplyFtrlV2<CPUDevice, T> {
347 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
348 typename TTypes<T>::Flat accum,
349 typename TTypes<T>::Flat linear,
350 typename TTypes<T>::ConstFlat grad,
351 typename TTypes<T>::ConstScalar lr,
352 typename TTypes<T>::ConstScalar l1,
353 typename TTypes<T>::ConstScalar l2,
354 typename TTypes<T>::ConstScalar l2_shrinkage,
355 typename TTypes<T>::ConstScalar lr_power) {
356 auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
357 auto new_accum = accum + grad * grad;
358 // special case for which lr_power=-0.5.
359 if (lr_power() == static_cast<T>(-0.5)) {
360 linear.device(d) +=
361 grad_with_shrinkage - (new_accum.sqrt() - accum.sqrt()) / lr() * var;
362 } else {
363 linear.device(d) +=
364 grad_with_shrinkage -
365 (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var;
366 }
367 auto x = (linear.constant(l1()) * linear.sign() - linear);
368 if (lr_power() == static_cast<T>(-0.5)) {
369 auto y = new_accum.sqrt() / new_accum.constant(lr()) +
370 linear.constant(static_cast<T>(2) * l2());
371 auto pre_shrink = x / y;
372 var.device(d) = (linear.abs() > linear.constant(l1()))
373 .select(pre_shrink, var.constant(static_cast<T>(0)));
374
375 } else {
376 auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) +
377 linear.constant(static_cast<T>(2) * l2());
378 auto pre_shrink = x / y;
379 var.device(d) = (linear.abs() > linear.constant(l1()))
380 .select(pre_shrink, var.constant(static_cast<T>(0)));
381 }
382 accum.device(d) += grad * grad;
383 }
384};
385
386template <typename T>
387struct ApplyFtrlV2MultiplyLinearByLr<CPUDevice, T> {
388 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
389 typename TTypes<T>::Flat accum,
390 typename TTypes<T>::Flat linear,
391 typename TTypes<T>::ConstFlat grad,
392 typename TTypes<T>::ConstScalar lr,
393 typename TTypes<T>::ConstScalar l1,
394 typename TTypes<T>::ConstScalar l2,
395 typename TTypes<T>::ConstScalar l2_shrinkage,
396 typename TTypes<T>::ConstScalar lr_power) {
397 auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
398 auto new_accum = accum + grad * grad;
399 // special case for which lr_power=-0.5.
400 if (lr_power() == static_cast<T>(-0.5)) {
401 linear.device(d) +=
402 grad_with_shrinkage * lr() - (new_accum.sqrt() - accum.sqrt()) * var;
403 } else {
404 linear.device(d) +=
405 grad_with_shrinkage * lr() -
406 (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) * var;
407 }
408 auto x = (linear.constant(l1() * lr()) * linear.sign() - linear);
409 if (lr_power() == static_cast<T>(-0.5)) {
410 auto y =
411 new_accum.sqrt() + linear.constant(static_cast<T>(2) * l2() * lr());
412 auto pre_shrink = x / y;
413 var.device(d) = (linear.abs() > linear.constant(l1() * lr()))
414 .select(pre_shrink, var.constant(static_cast<T>(0)));
415
416 } else {
417 auto y = new_accum.pow(-lr_power()) +
418 linear.constant(static_cast<T>(2) * l2() * lr());
419 auto pre_shrink = x / y;
420 var.device(d) = (linear.abs() > linear.constant(l1() * lr()))
421 .select(pre_shrink, var.constant(static_cast<T>(0)));
422 }
423 accum.device(d) += grad * grad;
424 }
425};
426
427template <typename T>
428struct ApplyFtrl<CPUDevice, T> {
429 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
430 typename TTypes<T>::Flat accum,
431 typename TTypes<T>::Flat linear,
432 typename TTypes<T>::ConstFlat grad,
433 typename TTypes<T>::ConstScalar lr,
434 typename TTypes<T>::ConstScalar l1,
435 typename TTypes<T>::ConstScalar l2,
436 typename TTypes<T>::ConstScalar lr_power) {
437 auto new_accum = accum + grad.square();
438 // special case for which lr_power=-0.5.
439 if (lr_power() == static_cast<T>(-0.5)) {
440 linear.device(d) += grad - (new_accum.sqrt() - accum.sqrt()) / lr() * var;
441 } else {
442 linear.device(d) +=
443 grad -
444 (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var;
445 }
446 auto x = (linear.constant(l1()) * linear.sign() - linear);
447 if (lr_power() == static_cast<T>(-0.5)) {
448 auto y = new_accum.sqrt() / new_accum.constant(lr()) +
449 linear.constant(static_cast<T>(2) * l2());
450 auto pre_shrink = x / y;
451 var.device(d) = (linear.abs() > linear.constant(l1()))
452 .select(pre_shrink, var.constant(static_cast<T>(0)));
453
454 } else {
455 auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) +
456 linear.constant(static_cast<T>(2) * l2());
457 auto pre_shrink = x / y;
458 var.device(d) = (linear.abs() > linear.constant(l1()))
459 .select(pre_shrink, var.constant(static_cast<T>(0)));
460 }
461 accum.device(d) += grad.square();
462 }
463};
464
465template <typename T>
466struct ApplyFtrlMultiplyLinearByLr<CPUDevice, T> {
467 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
468 typename TTypes<T>::Flat accum,
469 typename TTypes<T>::Flat linear,
470 typename TTypes<T>::ConstFlat grad,
471 typename TTypes<T>::ConstScalar lr,
472 typename TTypes<T>::ConstScalar l1,
473 typename TTypes<T>::ConstScalar l2,
474 typename TTypes<T>::ConstScalar lr_power) {
475 auto new_accum = accum + grad.square();
476 // special case for which lr_power=-0.5.
477 if (lr_power() == static_cast<T>(-0.5)) {
478 linear.device(d) += grad * lr() - (new_accum.sqrt() - accum.sqrt()) * var;
479 } else {
480 linear.device(d) +=
481 grad * lr() -
482 (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) * var;
483 }
484 auto x = (linear.constant(l1()) * lr() * linear.sign() - linear);
485 if (lr_power() == static_cast<T>(-0.5)) {
486 auto y =
487 new_accum.sqrt() + linear.constant(static_cast<T>(2) * l2() * lr());
488 auto pre_shrink = x / y;
489 var.device(d) = (linear.abs() > linear.constant(l1() * lr()))
490 .select(pre_shrink, var.constant(static_cast<T>(0)));
491
492 } else {
493 auto y = new_accum.pow(-lr_power()) +
494 linear.constant(static_cast<T>(2) * l2() * lr());
495 auto pre_shrink = x / y;
496 var.device(d) = (linear.abs() > linear.constant(l1() * lr()))
497 .select(pre_shrink, var.constant(static_cast<T>(0)));
498 }
499 accum.device(d) += grad.square();
500 }
501};
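// For reference, the four dense FTRL functors above share one closed-form
// update (restating the code; the "V2" variants add an L2-shrinkage term to
// the gradient, and the "MultiplyLinearByLr" variants keep `linear`
// pre-scaled by lr, so they use l1 * lr and 2 * l2 * lr and avoid the
// division by lr):
//   g'        = grad                              (ApplyFtrl)
//   g'        = grad + 2 * l2_shrinkage * var     (ApplyFtrlV2)
//   new_accum = accum + grad^2
//   sigma     = (new_accum^{-lr_power} - accum^{-lr_power}) / lr
//   linear   <- linear + g' - sigma * var
//   var      <- (|linear| > l1)
//                 ? (l1 * sign(linear) - linear) /
//                   (new_accum^{-lr_power} / lr + 2 * l2)
//                 : 0
//   accum    <- new_accum
// lr_power == -0.5 is special-cased so that new_accum^{-lr_power} becomes
// sqrt(new_accum). The sparse path below reuses the same formula through
// ComputeFtrl and FtrlCompute.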
502
503namespace {
504
505template <typename T>
506inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1,
507 const T& l2, const T& lr_power,
508 const bool multiply_linear_by_lr) {
509 T quadratic;
510 if (multiply_linear_by_lr) {
511 if (lr_power == static_cast<T>(-0.5)) {
512 quadratic = Eigen::numext::sqrt(accum) + static_cast<T>(2) * l2 * lr;
513 } else {
514 quadratic =
515 Eigen::numext::pow(accum, -lr_power) + static_cast<T>(2) * l2 * lr;
516 }
517 auto l1_reg_adjust = std::max(std::min(linear, l1 * lr), -l1 * lr);
518 return (l1_reg_adjust - linear) / quadratic;
519 } else {
520 if (lr_power == static_cast<T>(-0.5)) {
521 quadratic = Eigen::numext::sqrt(accum) / lr + static_cast<T>(2) * l2;
522 } else {
523 quadratic =
524 Eigen::numext::pow(accum, -lr_power) / lr + static_cast<T>(2) * l2;
525 }
526 auto l1_reg_adjust = std::max(std::min(linear, l1), -l1);
527 return (l1_reg_adjust - linear) / quadratic;
528 }
529}
530
531template <typename T, typename GradTy, typename GradeMaybeWithShrinkageTy,
532 typename AccumTy, typename LinearTy, typename VarTy>
533void ComputeFtrl(GradTy grad,
534 GradeMaybeWithShrinkageTy grad_maybe_with_shrinkage,
535 AccumTy accum, LinearTy linear, VarTy var, T l1_scalar,
536 T l2_scalar, bool multiply_linear_by_lr, T lr_power_scalar,
537 T lr_scalar) {
538 auto new_accum = accum + grad.square();
539 if (multiply_linear_by_lr) {
540 if (lr_power_scalar == static_cast<T>(-0.5)) {
541 linear += grad_maybe_with_shrinkage * lr_scalar -
542 (new_accum.sqrt() - accum.sqrt()) * var;
543 } else {
544 linear +=
545 grad_maybe_with_shrinkage * lr_scalar -
546 (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * var;
547 }
548 } else {
549 if (lr_power_scalar == static_cast<T>(-0.5)) {
550 linear += grad_maybe_with_shrinkage -
551 (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;
552 } else {
553 linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) -
554 accum.pow(-lr_power_scalar)) /
555 lr_scalar * var;
556 }
557 }
558 auto l1_reg_adjust =
559 (multiply_linear_by_lr ? linear.cwiseMin(l1_scalar * lr_scalar)
560 .cwiseMax(-l1_scalar * lr_scalar)
561 : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar));
562 auto x = l1_reg_adjust - linear;
563 if (multiply_linear_by_lr) {
564 if (lr_power_scalar == static_cast<T>(-0.5)) {
565 auto y = new_accum.sqrt() +
566 linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar);
567 var = x / y;
568 } else {
569 auto y = new_accum.pow(-lr_power_scalar) +
570 linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar);
571 var = x / y;
572 }
573 } else {
574 if (lr_power_scalar == static_cast<T>(-0.5)) {
575 auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) +
576 linear.constant(static_cast<T>(2) * l2_scalar);
577 var = x / y;
578 } else {
579 auto y = new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) +
580 linear.constant(static_cast<T>(2) * l2_scalar);
581 var = x / y;
582 }
583 }
584 accum += grad.square();
585}
586} // namespace
587
588template <typename T, typename Tindex, bool has_l2_shrinkage>
589struct SparseApplyFtrl<CPUDevice, T, Tindex, has_l2_shrinkage> {
590 Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var_flat,
591 typename TTypes<T>::Matrix accum_flat,
592 typename TTypes<T>::Matrix linear_flat,
593 typename TTypes<T>::ConstScalar lr,
594 typename TTypes<T>::ConstScalar l1,
595 typename TTypes<T>::ConstScalar l2,
596 typename TTypes<T>::ConstScalar l2_shrinkage,
597 typename TTypes<T>::ConstScalar lr_power,
598 typename TTypes<T>::ConstMatrix grad_flat,
599 typename TTypes<Tindex>::ConstVec indices_vec,
600 int64_t inner_dim, bool multiply_linear_by_lr) {
601 const Tindex N = static_cast<Tindex>(indices_vec.dimension(0));
602 if (N > 0) {
603 T lr_scalar = lr();
604 T l1_scalar = l1();
605 T l2_scalar = l2();
606 T l2_shrinkage_scalar;
607 if (has_l2_shrinkage) {
608 l2_shrinkage_scalar = l2_shrinkage();
609 }
610 T lr_power_scalar = lr_power();
611 if (inner_dim > 1) {
612 const Tindex first_dim_size =
613 static_cast<Tindex>(var_flat.dimension(0));
614
615 for (Tindex i = 0; i < N; i++) {
616 const Tindex index = internal::SubtleMustCopy(indices_vec(i));
617 if (!FastBoundsCheck(index, first_dim_size)) {
618 return errors::InvalidArgument(
619 strings::StrCat("Index ", index, " at offset ", i,
620 " in indices is out of range"));
621 }
622 auto accum = accum_flat.template chip<0>(index);
623 auto linear = linear_flat.template chip<0>(index);
624 auto grad = grad_flat.template chip<0>(i);
625 auto var = var_flat.template chip<0>(index);
626
627 if (has_l2_shrinkage) {
628 auto grad_with_shrinkage =
629 grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
630 ComputeFtrl(/*grad=*/grad,
631 /*grad_maybe_with_shrinkage=*/grad_with_shrinkage,
632 /*accum=*/accum, /*linear=*/linear, /*var=*/var,
633 /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar,
634 /*multiply_linear_by_lr=*/multiply_linear_by_lr,
635 /*lr_power_scalar=*/lr_power_scalar,
636 /*lr_scalar=*/lr_scalar);
637 } else {
638 ComputeFtrl(/*grad=*/grad, /*grad_maybe_with_shrinkage=*/grad,
639 /*accum=*/accum, /*linear=*/linear, /*var=*/var,
640 /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar,
641 /*multiply_linear_by_lr=*/multiply_linear_by_lr,
642 /*lr_power_scalar=*/lr_power_scalar,
643 /*lr_scalar=*/lr_scalar);
644 }
645 }
646 } else {
647 const Tindex first_dim_size = accum_flat.size();
648
649 for (Tindex i = 0; i < N; i++) {
650 const Tindex index = internal::SubtleMustCopy(indices_vec(i));
651 if (!FastBoundsCheck(index, first_dim_size)) {
652 return errors::InvalidArgument(
653 strings::StrCat("Index ", index, " at offset ", i,
654 " in indices is out of range"));
655 }
656 T& a = accum_flat(index);
657 T& l = linear_flat(index);
658 T& v = var_flat(index);
659 T g;
660 if (has_l2_shrinkage) {
661 g = grad_flat(i) +
662 (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(index));
663 } else {
664 g = grad_flat(i);
665 }
666
667 T updated_a = a + grad_flat(i) * grad_flat(i);
668 using Eigen::numext::pow;
669 T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar);
670 if (!multiply_linear_by_lr) {
671 sigma /= lr_scalar;
672 }
673 T updated_l = (multiply_linear_by_lr ? l + g * lr_scalar - sigma * v
674 : l + g - sigma * v);
675 v = FtrlCompute(updated_a, updated_l, lr_scalar, l1_scalar, l2_scalar,
676 lr_power_scalar, multiply_linear_by_lr);
677 a = updated_a;
678 l = updated_l;
679 }
680 }
681 }
682 return OkStatus();
683 }
684};
685
686template <typename T>
687struct ApplyMomentum<CPUDevice, T> {
688 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
689 typename TTypes<T>::Flat accum,
690 typename TTypes<T>::ConstScalar lr,
691 typename TTypes<T>::ConstFlat grad,
692 typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
693 accum.device(d) = accum * momentum() + grad;
694 if (use_nesterov) {
695 var.device(d) -= grad * lr() + accum * momentum() * lr();
696 } else {
697 var.device(d) -= accum * lr();
698 }
699 }
700};
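// For reference, the momentum step implemented above is (restating the code):
//   accum <- momentum * accum + grad
//   var   <- var - lr * accum                         (standard)
//   var   <- var - lr * (grad + momentum * accum)     (use_nesterov)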
701
702template <typename T>
703struct ApplyKerasMomentum<CPUDevice, T> {
704 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
705 typename TTypes<T>::Flat accum,
706 typename TTypes<T>::ConstScalar lr,
707 typename TTypes<T>::ConstFlat grad,
708 typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
709 accum.device(d) = accum * momentum() - grad * lr();
710 if (use_nesterov) {
711 var.device(d) += (accum * momentum() - grad * lr());
712 } else {
713 var.device(d) += accum;
714 }
715 }
716};
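// For reference, the Keras-style momentum step above differs from
// ApplyMomentum by folding lr into the accumulator (restating the code):
//   accum <- momentum * accum - lr * grad
//   var   <- var + accum                              (standard)
//   var   <- var + momentum * accum - lr * grad       (use_nesterov)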
717
718template <typename T, typename Tindex>
719struct SparseApplyKerasMomentum<CPUDevice, T, Tindex> {
720 Tindex operator()(const CPUDevice& d, typename TTypes<T>::Matrix var,
721 typename TTypes<T>::Matrix accum,
722 typename TTypes<T>::ConstScalar lr,
723 typename TTypes<T>::ConstMatrix grad,
724 typename TTypes<Tindex>::ConstFlat indices,
725 typename TTypes<T>::ConstScalar momentum,
726 bool use_nesterov) {
727 const Tindex N = static_cast<Tindex>(indices.size());
728 const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0));
729 for (Tindex i = 0; i < N; i++) {
730 const Tindex index = internal::SubtleMustCopy(indices(i));
731 if (!FastBoundsCheck(index, first_dim_size)) return i;
732 auto a = accum.template chip<0>(index);
733 auto g = grad.template chip<0>(i);
734 auto v = var.template chip<0>(index);
735 a = a * a.constant(momentum()) - g * g.constant(lr());
736 if (use_nesterov) {
737 v += a * a.constant(momentum()) - g * g.constant(lr());
738 } else {
739 v += a;
740 }
741 }
742 return -1;
743 }
744};
745
746template <typename T, typename Tindex>
747struct SparseApplyAdadelta<CPUDevice, T, Tindex> {
748 void operator()(const CPUDevice& d, typename TTypes<T>::Matrix var,
749 typename TTypes<T>::Matrix accum,
750 typename TTypes<T>::Matrix accum_update,
751 typename TTypes<T>::ConstScalar lr,
752 typename TTypes<T>::ConstScalar rho,
753 typename TTypes<T>::ConstScalar epsilon,
754 typename TTypes<T>::ConstMatrix grad,
755 typename TTypes<Tindex>::ConstFlat indices) {
756 const Tindex N = static_cast<Tindex>(indices.size());
757 for (Tindex i = 0; i < N; i++) {
758 const Tindex index = indices(i);
759 auto a = accum.template chip<0>(index);
760 auto a_update = accum_update.template chip<0>(index);
761 auto g = grad.template chip<0>(i);
762
763 a = a * a.constant(rho()) + g.square() * g.constant(T(1) - rho());
764 const auto update = (a_update + a_update.constant(epsilon())).sqrt() *
765 (a + a.constant(epsilon())).rsqrt() * g;
766 auto v = var.template chip<0>(index);
767 v -= update * update.constant(lr());
768 a_update = a_update * a_update.constant(rho()) +
769 update.square() * update.constant(static_cast<T>(1) - rho());
770 }
771 }
772};
773
774template <typename Device, typename T>
775struct ApplyAdamNonCuda {
776 void operator()(const Device& d, typename TTypes<T>::Flat var,
777 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
778 typename TTypes<T>::ConstScalar beta1_power,
779 typename TTypes<T>::ConstScalar beta2_power,
780 typename TTypes<T>::ConstScalar lr,
781 typename TTypes<T>::ConstScalar beta1,
782 typename TTypes<T>::ConstScalar beta2,
783 typename TTypes<T>::ConstScalar epsilon,
784 typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
785 // Get the parameter length and check whether it is divisible by the packet size.
786 Index length = var.size();
787 Index packet_size = Eigen::internal::packet_traits<T>::size;
788 if (length % packet_size == 0) {
789 length = length / packet_size;
790 } else {
791 packet_size = 1;
792 }
793
794 T* var_ptr = var.data();
795 T* m_ptr = m.data();
796 T* v_ptr = v.data();
797 const T* g_ptr = grad.data();
798 const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) /
799 (T(1) - beta1_power());
800 // beta1 == μ
801 // beta2 == ν
802 // v == n
803 // var == θ
804
805 auto shard = [var_ptr, m_ptr, v_ptr, g_ptr, alpha, beta1, beta2, epsilon,
806 use_nesterov, packet_size](int begin, int end) {
807 int t_size = (end - begin) * packet_size;
808 begin = begin * packet_size;
809 auto var = typename TTypes<T>::UnalignedTensor(var_ptr + begin, t_size);
810 auto m = typename TTypes<T>::UnalignedTensor(m_ptr + begin, t_size);
811 auto v = typename TTypes<T>::UnalignedTensor(v_ptr + begin, t_size);
812 auto g = typename TTypes<T>::UnalignedConstTensor(g_ptr + begin, t_size);
813
814 if (use_nesterov) {
815 m += (g - m) * (T(1) - beta1());
816 v += (g.square() - v) * (T(1) - beta2());
817 var -= ((g * (T(1) - beta1()) + beta1() * m) * alpha) /
818 (v.sqrt() + epsilon());
819 } else {
820 m += (g - m) * (T(1) - beta1());
821 v += (g.square() - v) * (T(1) - beta2());
822 var -= (m * alpha) / (v.sqrt() + epsilon());
823 }
824 };
825
826 // Input data: var, v, m, grad.
827 // Output data: var, v, m.
828 const int input_bytes = length * packet_size * sizeof(T) * 4;
829 const int output_bytes = length * packet_size * sizeof(T) * 3;
830 const int compute_cycles =
831 // Consider Sub as Add
832 (Eigen::TensorOpCost::AddCost<int>() * 5 +
833 Eigen::TensorOpCost::MulCost<int>() * 2 +
834 Eigen::TensorOpCost::AddCost<T>() * 10 +
835 Eigen::TensorOpCost::MulCost<T>() * 6 +
836 Eigen::TensorOpCost::DivCost<T>()) *
837 length;
838 const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
839
840 // The Eigen device would have to update 3 variables with 3 different
841 // expressions, which is bad for cache locality on CPU. Use parallelFor
842 // instead of "regular" tensor expressions to get better performance.
843 d.parallelFor(length, cost, shard);
844 }
845};
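// For reference, the bias-corrected Adam step implemented above is (restating
// the code):
//   alpha = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//   m    <- m + (1 - beta1) * (grad - m)
//   v    <- v + (1 - beta2) * (grad^2 - v)
//   var  <- var - alpha * m / (sqrt(v) + epsilon)                     (plain)
//   var  <- var - alpha * ((1 - beta1) * grad + beta1 * m) /
//                         (sqrt(v) + epsilon)                 (use_nesterov)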
846
847template <typename T>
848struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};
849
850template <typename T>
851struct ApplyAdamWithAmsgrad<CPUDevice, T> {
852 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
853 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
854 typename TTypes<T>::Flat vhat,
855 typename TTypes<T>::ConstScalar beta1_power,
856 typename TTypes<T>::ConstScalar beta2_power,
857 typename TTypes<T>::ConstScalar lr,
858 typename TTypes<T>::ConstScalar beta1,
859 typename TTypes<T>::ConstScalar beta2,
860 typename TTypes<T>::ConstScalar epsilon,
861 typename TTypes<T>::ConstFlat grad) {
862 const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) /
863 (T(1) - beta1_power());
864
865 m.device(d) += (grad - m) * (T(1) - beta1());
866 v.device(d) += (grad.square() - v) * (T(1) - beta2());
867 vhat.device(d) = vhat.cwiseMax(v);
868 var.device(d) -= (m * alpha) / (vhat.sqrt() + epsilon());
869 }
870};
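// For reference, the AMSGrad variant above additionally tracks the running
// maximum of the second moment (restating the code):
//   vhat <- max(vhat, v)
//   var  <- var - alpha * m / (sqrt(vhat) + epsilon)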
871
872template <typename Device, typename T>
873struct ApplyAdaMaxNonCuda {
874 void operator()(const Device& d, typename TTypes<T>::Flat var,
875 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
876 typename TTypes<T>::ConstScalar beta1_power,
877 typename TTypes<T>::ConstScalar lr,
878 typename TTypes<T>::ConstScalar beta1,
879 typename TTypes<T>::ConstScalar beta2,
880 typename TTypes<T>::ConstScalar epsilon,
881 typename TTypes<T>::ConstFlat grad) {
882 m.device(d) += (grad - m) * (T(1) - beta1());
883 // Here v is u in Section 7.1 of the Adam paper (the AdaMax update).
884 v.device(d) = (beta2() * v).cwiseMax(grad.abs());
885 // var is θ in section 7.1
886 var.device(d) -= lr() / (T(1) - beta1_power()) * (m / (v + epsilon()));
887 }
888};
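// For reference, the AdaMax step implemented above is (restating the code;
// v is the infinity-norm accumulator, i.e. u in the paper):
//   m   <- m + (1 - beta1) * (grad - m)
//   v   <- max(beta2 * v, |grad|)
//   var <- var - lr / (1 - beta1_power) * m / (v + epsilon)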
889
890template <typename T>
891struct ApplyAdaMax<CPUDevice, T> : ApplyAdaMaxNonCuda<CPUDevice, T> {};
892
893template <typename T>
894struct ApplyRMSProp<CPUDevice, T> {
895 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
896 typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
897 typename TTypes<T>::ConstScalar lr,
898 typename TTypes<T>::ConstScalar rho,
899 typename TTypes<T>::ConstScalar momentum,
900 typename TTypes<T>::ConstScalar epsilon,
901 typename TTypes<T>::ConstFlat grad) {
902 ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho());
903 mom.device(d) =
904 mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt());
905 var.device(d) -= mom;
906 }
907};
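// For reference, the RMSProp step implemented above is (restating the code):
//   ms  <- rho * ms + (1 - rho) * grad^2
//   mom <- momentum * mom + lr * grad / sqrt(ms + epsilon)
//   var <- var - mom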
908
909template <typename T>
910struct ApplyCenteredRMSProp<CPUDevice, T> {
911 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
912 typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
913 typename TTypes<T>::Flat mom,
914 typename TTypes<T>::ConstScalar lr,
915 typename TTypes<T>::ConstScalar rho,
916 typename TTypes<T>::ConstScalar momentum,
917 typename TTypes<T>::ConstScalar epsilon,
918 typename TTypes<T>::ConstFlat grad) {
919 ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho());
920 mg.device(d) += (grad - mg) * (static_cast<T>(1) - rho());
921 auto denom = (ms - mg.square()) + epsilon();
922 mom.device(d) = mom * momentum() + (grad * lr()) / denom.sqrt();
923 var.device(d) -= mom;
924 }
925};
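// For reference, the centered RMSProp step above subtracts the squared mean
// gradient from the second-moment estimate (restating the code):
//   mg  <- rho * mg + (1 - rho) * grad
//   ms  <- rho * ms + (1 - rho) * grad^2
//   mom <- momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
//   var <- var - mom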
926
927template <typename T>
928struct ApplyAddSign<CPUDevice, T> {
929 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
930 typename TTypes<T>::Flat m,
931 typename TTypes<T>::ConstScalar lr,
932 typename TTypes<T>::ConstScalar alpha,
933 typename TTypes<T>::ConstScalar sign_decay,
934 typename TTypes<T>::ConstScalar beta,
935 typename TTypes<T>::ConstFlat grad) {
936 m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
937 auto sign_gm = grad.sign() * m.sign();
938 var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
939 }
940};
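// For reference, the AddSign step implemented above is (restating the code):
//   m   <- beta * m + (1 - beta) * grad
//   var <- var - lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad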
941
942template <typename T>
943struct ApplyPowerSign<CPUDevice, T> {
944 void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
945 typename TTypes<T>::Flat m,
946 typename TTypes<T>::ConstScalar lr,
947 typename TTypes<T>::ConstScalar logbase,
948 typename TTypes<T>::ConstScalar sign_decay,
949 typename TTypes<T>::ConstScalar beta,
950 typename TTypes<T>::ConstFlat grad) {
951 m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
952 auto sign_gm = grad.sign() * m.sign();
953 auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
954 var.device(d) -= lr() * grad_scale * grad;
955 }
956};
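// For reference, the PowerSign step implemented above is (restating the code):
//   m   <- beta * m + (1 - beta) * grad
//   var <- var - lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad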
957
958} // namespace functor
959
960template <typename Device, typename T>
961class ApplyGradientDescentOp : public OpKernel {
962 public:
963 explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
964 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
965 }
966
967 void Compute(OpKernelContext* ctx) override {
968 const bool sparse = false;
969 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
970 ctx, use_exclusive_lock_, sparse, {0});
971 Tensor var;
972 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
973 ctx, 0, use_exclusive_lock_, sparse, &var));
974
975 OP_REQUIRES(
976 ctx, var.IsInitialized(),
977 errors::FailedPrecondition(
978 "Attempting to use uninitialized variables: ", requested_input(0)));
979 const Tensor& alpha = ctx->input(1);
980 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
981 errors::InvalidArgument("alpha is not a scalar: ",
982 alpha.shape().DebugString()));
983 const Tensor& delta = ctx->input(2);
984 OP_REQUIRES(
985 ctx, var.shape().IsSameSize(delta.shape()),
986 errors::InvalidArgument("var and delta do not have the same shape",
987 var.shape().DebugString(), " ",
988 delta.shape().DebugString()));
989
990 const Device& device = ctx->template eigen_device<Device>();
991 functor::ApplyGradientDescent<Device, T>()(
992 device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>());
993
994 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
995 }
996
997 private:
998 bool use_exclusive_lock_;
999};
1000
1001#define REGISTER_KERNELS(D, T) \
1002 REGISTER_KERNEL_BUILDER( \
1003 Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \
1004 ApplyGradientDescentOp<D##Device, T>); \
1005 REGISTER_KERNEL_BUILDER(Name("ResourceApplyGradientDescent") \
1006 .Device(DEVICE_##D) \
1007 .HostMemory("var") \
1008 .TypeConstraint<T>("T"), \
1009 ApplyGradientDescentOp<D##Device, T>);
1010#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
1011
1012TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
1013TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
1014
1015#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1016// Forward declarations of the functor specializations for GPU.
1017namespace functor {
1018#define DECLARE_GPU_SPEC(T) \
1019 template <> \
1020 void ApplyGradientDescent<GPUDevice, T>::operator()( \
1021 const GPUDevice& d, typename TTypes<T>::Flat var, \
1022 typename TTypes<T>::ConstScalar alpha, \
1023 typename TTypes<T>::ConstFlat delta); \
1024 extern template struct ApplyGradientDescent<GPUDevice, T>;
1025DECLARE_GPU_SPEC(Eigen::half);
1026DECLARE_GPU_SPEC(float);
1027DECLARE_GPU_SPEC(double);
1028DECLARE_GPU_SPEC(complex64);
1029DECLARE_GPU_SPEC(complex128);
1030#undef DECLARE_GPU_SPEC
1031} // namespace functor
1032
1033REGISTER_KERNELS(GPU, Eigen::half);
1034REGISTER_KERNELS(GPU, float);
1035REGISTER_KERNELS(GPU, double);
1036REGISTER_KERNELS(GPU, complex64);
1037REGISTER_KERNELS(GPU, complex128);
1038#endif
1039
1040#undef REGISTER_CPU_KERNELS
1041#undef REGISTER_KERNELS
1042
1043template <typename Device, typename T>
1044class ApplyAdadeltaOp : public OpKernel {
1045 public:
1046 explicit ApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1047 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1048 }
1049
1050 void Compute(OpKernelContext* ctx) override {
1051 const bool sparse = false;
1052 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
1053 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
1054 DoValidate(ctx);
1055 if (!ctx->status().ok()) return;
1056 DoCompute(ctx);
1057 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1058 }
1059
1060 private:
1061 bool use_exclusive_lock_;
1062
1063 void DoValidate(OpKernelContext* ctx) {
1064 Tensor var;
1065 const bool sparse = false;
1066 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1067 ctx, 0, use_exclusive_lock_, sparse, &var));
1068 Tensor accum;
1069 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1070 ctx, 1, use_exclusive_lock_, sparse, &accum));
1071 Tensor accum_update;
1072 OP_REQUIRES_OK(
1073 ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_,
1074 sparse, &accum_update));
1075
1076 OP_REQUIRES(
1077 ctx, var.IsInitialized(),
1078 errors::FailedPrecondition(
1079 "Attempting to use uninitialized variables: ", requested_input(0)));
1080 OP_REQUIRES(
1081 ctx, accum.IsInitialized(),
1082 errors::FailedPrecondition(
1083 "Attempting to use uninitialized variables: ", requested_input(1)));
1084 OP_REQUIRES(
1085 ctx, accum_update.IsInitialized(),
1086 errors::FailedPrecondition(
1087 "Attempting to use uninitialized variables: ", requested_input(2)));
1088
1089 const Tensor& lr = ctx->input(3);
1090 const Tensor& rho = ctx->input(4);
1091 const Tensor& epsilon = ctx->input(5);
1092 const Tensor& grad = ctx->input(6);
1093
1094 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
1095 errors::InvalidArgument("lr is not a scalar: ",
1096 lr.shape().DebugString()));
1097
1098 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
1099 errors::InvalidArgument("rho is not a scalar: ",
1100 rho.shape().DebugString()));
1101
1102 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
1103 errors::InvalidArgument("epsilon is not a scalar: ",
1104 epsilon.shape().DebugString()));
1105
1106 OP_REQUIRES(
1107 ctx, var.shape().IsSameSize(accum.shape()),
1108 errors::InvalidArgument("var and accum do not have the same shape",
1109 var.shape().DebugString(), " ",
1110 accum.shape().DebugString()));
1111 OP_REQUIRES(
1112 ctx, var.shape().IsSameSize(grad.shape()),
1113 errors::InvalidArgument("var and grad do not have the same shape",
1114 var.shape().DebugString(), " ",
1115 grad.shape().DebugString()));
1116 }
1117
1118 void DoCompute(OpKernelContext* ctx) {
1119 const Device& device = ctx->template eigen_device<Device>();
1120 Tensor var;
1121 const bool sparse = false;
1122 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1123 ctx, 0, use_exclusive_lock_, sparse, &var));
1124 Tensor accum;
1125 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1126 ctx, 1, use_exclusive_lock_, sparse, &accum));
1127 Tensor accum_update;
1128 OP_REQUIRES_OK(
1129 ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_,
1130 sparse, &accum_update));
1131
1132 const Tensor& lr = ctx->input(3);
1133 const Tensor& rho = ctx->input(4);
1134 const Tensor& epsilon = ctx->input(5);
1135 const Tensor& grad = ctx->input(6);
1136
1137 functor::ApplyAdadelta<Device, T>()(
1138 device, var.flat<T>(), accum.flat<T>(), accum_update.flat<T>(),
1139 lr.scalar<T>(), rho.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
1140 }
1141};
1142
1143#define REGISTER_KERNELS(D, T) \
1144 REGISTER_KERNEL_BUILDER( \
1145 Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
1146 ApplyAdadeltaOp<D##Device, T>); \
1147 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdadelta") \
1148 .Device(DEVICE_##D) \
1149 .HostMemory("var") \
1150 .HostMemory("accum") \
1151 .HostMemory("accum_update") \
1152 .TypeConstraint<T>("T"), \
1153 ApplyAdadeltaOp<D##Device, T>);
1154#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
1155
1156TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
1157TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
1158
1159#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1160// Forward declarations of the functor specializations for GPU.
1161namespace functor {
1162#define DECLARE_GPU_SPEC(T) \
1163 template <> \
1164 void ApplyAdadelta<GPUDevice, T>::operator()( \
1165 const GPUDevice& d, typename TTypes<T>::Flat var, \
1166 typename TTypes<T>::Flat accum, typename TTypes<T>::Flat accum_update, \
1167 typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \
1168 typename TTypes<T>::ConstScalar epsilon, \
1169 typename TTypes<T>::ConstFlat grad); \
1170 extern template struct ApplyAdadelta<GPUDevice, T>;
1171DECLARE_GPU_SPEC(Eigen::half);
1172DECLARE_GPU_SPEC(float);
1173DECLARE_GPU_SPEC(double);
1174DECLARE_GPU_SPEC(complex64);
1175DECLARE_GPU_SPEC(complex128);
1176#undef DECLARE_GPU_SPEC
1177} // namespace functor
1178
1179REGISTER_KERNELS(GPU, Eigen::half);
1180REGISTER_KERNELS(GPU, float);
1181REGISTER_KERNELS(GPU, double);
1182REGISTER_KERNELS(GPU, complex64);
1183REGISTER_KERNELS(GPU, complex128);
1184#endif
1185#undef REGISTER_CPU_KERNELS
1186#undef REGISTER_KERNELS
1187
1188template <typename T, typename Device, typename Tindex>
1189class SparseApplyAdadeltaOp : public OpKernel {
1190 public:
1191 explicit SparseApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1192 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1193 }
1194
1195 void Compute(OpKernelContext* ctx) override {
1196 const bool sparse = true;
1197 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
1198 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
1199 DoCompute(ctx);
1200 }
1201
1202 void DoCompute(OpKernelContext* ctx) {
1203 Tensor var;
1204 const bool sparse = true;
1205 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1206 ctx, 0, use_exclusive_lock_, sparse, &var));
1207 Tensor accum_grad;
1208 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1209 ctx, 1, use_exclusive_lock_, sparse, &accum_grad));
1210 Tensor accum_update;
1211 OP_REQUIRES_OK(
1212 ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_,
1213 sparse, &accum_update));
1214 OP_REQUIRES(
1215 ctx, var.IsInitialized(),
1216 errors::FailedPrecondition(
1217 "Attempting to use uninitialized variables: ", requested_input(0)));
1218 OP_REQUIRES(
1219 ctx, accum_grad.IsInitialized(),
1220 errors::FailedPrecondition(
1221 "Attempting to use uninitialized variables: ", requested_input(1)));
1222 OP_REQUIRES(
1223 ctx, accum_update.IsInitialized(),
1224 errors::FailedPrecondition(
1225 "Attempting to use uninitialized variables: ", requested_input(2)));
1226 OP_REQUIRES(
1227 ctx, var.shape().IsSameSize(accum_grad.shape()),
1228 errors::InvalidArgument("var and accum_grad do not have the same shape",
1229 var.shape().DebugString(), " ",
1230 accum_grad.shape().DebugString()));
1231 OP_REQUIRES(ctx, var.shape().IsSameSize(accum_update.shape()),
1232 errors::InvalidArgument(
1233 "var and accum_update do not have the same shape",
1234 var.shape().DebugString(), " ",
1235 accum_update.shape().DebugString()));
1236 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
1237 errors::InvalidArgument("var must be at least 1 dimensional"));
1238
1239 const Tensor& lr = ctx->input(3);
1240 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
1241 errors::InvalidArgument("lr is not a scalar: ",
1242 lr.shape().DebugString()));
1243 const Tensor& rho = ctx->input(4);
1244 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
1245 errors::InvalidArgument("rho is not a scalar: ",
1246 rho.shape().DebugString()));
1247 const Tensor& epsilon = ctx->input(5);
1248 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
1249 errors::InvalidArgument("epsilon is not a scalar: ",
1250 epsilon.shape().DebugString()));
1251 const Tensor& grad = ctx->input(6);
1252 const Tensor& indices = ctx->input(7);
1253 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
1254 errors::InvalidArgument("indices must be one-dimensional"));
1255
1256 for (int d = 1; d < var.dims(); d++) {
1257 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
1258 errors::InvalidArgument(strings::StrCat(
1259 "var and grad must match in dimension ", d)));
1260 }
1261 const Tindex N = indices.dim_size(0);
1262 OP_REQUIRES(
1263 ctx, grad.dim_size(0) == N,
1264 errors::InvalidArgument(
1265 "grad must be the same size as indices in the first dimension."));
1266
1267 if (N > 0) {
1268 const Tindex first_dim_size = var.dim_size(0);
1269 // Validate all the indices are in range
1270 auto indices_vec = indices.vec<Tindex>();
1271 for (Tindex i = 0; i < N; i++) {
1272 const Tindex index = indices_vec(i);
1273 OP_REQUIRES(ctx,
1274 (!std::is_same<Device, CPUDevice>::value ||
1275 (index >= 0 && index < first_dim_size)),
1276 errors::InvalidArgument(
1277 strings::StrCat("Index ", index, " at offset ", i,
1278 " in indices is out of range")));
1279 }
1280
1281 const Device& device = ctx->template eigen_device<Device>();
1282 functor::SparseApplyAdadelta<Device, T, Tindex>()(
1283 device, var.flat_outer_dims<T>(), accum_grad.flat_outer_dims<T>(),
1284 accum_update.flat_outer_dims<T>(), lr.scalar<T>(), rho.scalar<T>(),
1285 epsilon.scalar<T>(), grad.flat_outer_dims<T>(), indices_vec);
1286 }
1287
1288 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1289 }
1290
1291 private:
1292 bool use_exclusive_lock_;
1293};
1294
1295#define REGISTER_KERNELS(T, D, Tindices) \
1296 REGISTER_KERNEL_BUILDER(Name("SparseApplyAdadelta") \
1297 .Device(DEVICE_##D) \
1298 .TypeConstraint<T>("T") \
1299 .TypeConstraint<Tindices>("Tindices"), \
1300 SparseApplyAdadeltaOp<T, D##Device, Tindices>); \
1301 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdadelta") \
1302 .Device(DEVICE_##D) \
1303 .TypeConstraint<T>("T") \
1304 .TypeConstraint<Tindices>("Tindices"), \
1305 SparseApplyAdadeltaOp<T, D##Device, Tindices>);
1306#define REGISTER_CPU_KERNELS(T) \
1307 REGISTER_KERNELS(T, CPU, int32); \
1308 REGISTER_KERNELS(T, CPU, int64_t);
1309
1310TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
1311TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
1312
1313#undef REGISTER_CPU_KERNELS
1314
1315#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1316// Forward declarations of the functor specializations for GPU.
1317namespace functor {
1318#define DECLARE_GPU_SPEC(T, Tindex) \
1319 template <> \
1320 void SparseApplyAdadelta<GPUDevice, T, Tindex>::operator()( \
1321 const GPUDevice& d, typename TTypes<T>::Matrix var, \
1322 typename TTypes<T>::Matrix accum, \
1323 typename TTypes<T>::Matrix accum_update, \
1324 typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \
1325 typename TTypes<T>::ConstScalar epsilon, \
1326 typename TTypes<T>::ConstMatrix grad, \
1327 typename TTypes<Tindex>::ConstFlat indices); \
1328 extern template struct SparseApplyAdadelta<GPUDevice, T, Tindex>;
1329DECLARE_GPU_SPEC(Eigen::half, int32);
1330DECLARE_GPU_SPEC(Eigen::half, int64_t);
1331DECLARE_GPU_SPEC(float, int32);
1332DECLARE_GPU_SPEC(float, int64_t);
1333DECLARE_GPU_SPEC(double, int32);
1334DECLARE_GPU_SPEC(double, int64_t);
1335DECLARE_GPU_SPEC(complex64, int32);
1336DECLARE_GPU_SPEC(complex64, int64_t);
1337DECLARE_GPU_SPEC(complex128, int32);
1338DECLARE_GPU_SPEC(complex128, int64_t);
1339#undef DECLARE_GPU_SPEC
1340} // namespace functor
1341
1342#define REGISTER_GPU_KERNELS(T) \
1343 REGISTER_KERNELS(T, GPU, int32); \
1344 REGISTER_KERNELS(T, GPU, int64_t);
1345
1346REGISTER_GPU_KERNELS(Eigen::half);
1347REGISTER_GPU_KERNELS(float);
1348REGISTER_GPU_KERNELS(double);
1349REGISTER_GPU_KERNELS(complex64);
1350REGISTER_GPU_KERNELS(complex128);
1351#undef REGISTER_GPU_KERNELS
1352#endif
1353#undef REGISTER_KERNELS
1354
1355// Note: this op works on CPU only.
1356template <typename Device, typename T>
1357class ApplyProximalGradientDescentOp : public OpKernel {
1358 public:
1359 explicit ApplyProximalGradientDescentOp(OpKernelConstruction* ctx)
1360 : OpKernel(ctx) {
1361 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1362 }
1363
1364 void Compute(OpKernelContext* ctx) override {
1365 const bool sparse = false;
1366 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
1367 ctx, use_exclusive_lock_, sparse, {0});
1368 Tensor var;
1369 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1370 ctx, 0, use_exclusive_lock_, sparse, &var));
1371
1372 OP_REQUIRES(
1373 ctx, var.IsInitialized(),
1374 errors::FailedPrecondition(
1375 "Attempting to use uninitialized variables: ", requested_input(0)));
1376 const Tensor& alpha = ctx->input(1);
1377 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
1378 errors::InvalidArgument("alpha is not a scalar: ",
1379 alpha.shape().DebugString()));
1380 const Tensor& l1 = ctx->input(2);
1381 OP_REQUIRES(
1382 ctx, TensorShapeUtils::IsScalar(l1.shape()),
1383 errors::InvalidArgument("l1 regularization strength is not a scalar: ",
1384 l1.shape().DebugString()));
1385 const Tensor& l2 = ctx->input(3);
1386 OP_REQUIRES(
1387 ctx, TensorShapeUtils::IsScalar(l2.shape()),
1388 errors::InvalidArgument("l2 regularization strength is not a scalar: ",
1389 l2.shape().DebugString()));
1390
1391 const Tensor& delta = ctx->input(4);
1392 OP_REQUIRES(
1393 ctx, var.shape().IsSameSize(delta.shape()),
1394 errors::InvalidArgument("var and delta do not have the same shape",
1395 var.shape().DebugString(), " ",
1396 delta.shape().DebugString()));
1397
1398 const Device& device = ctx->template eigen_device<Device>();
1399 functor::ApplyProximalGradientDescent<Device, T>()(
1400 device, var.flat<T>(), alpha.scalar<T>(), l1.scalar<T>(),
1401 l2.scalar<T>(), delta.flat<T>());
1402
1403 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1404 }
1405
1406 private:
1407 bool use_exclusive_lock_;
1408};
1409
1410#define REGISTER_KERNELS(D, T) \
1411 REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent") \
1412 .Device(DEVICE_##D) \
1413 .TypeConstraint<T>("T"), \
1414 ApplyProximalGradientDescentOp<D##Device, T>); \
1415 REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalGradientDescent") \
1416 .HostMemory("var") \
1417 .Device(DEVICE_##D) \
1418 .TypeConstraint<T>("T"), \
1419 ApplyProximalGradientDescentOp<D##Device, T>);
1420
1421REGISTER_KERNELS(CPU, float);
1422REGISTER_KERNELS(CPU, double);
1423#undef REGISTER_KERNELS
1424
1425// Note: this op works on CPU only.
1426template <typename T, typename Tindex>
1427class SparseApplyProximalGradientDescentOp : public OpKernel {
1428 public:
1429 explicit SparseApplyProximalGradientDescentOp(OpKernelConstruction* ctx)
1430 : OpKernel(ctx) {
1431 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1432 }
1433
1434 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
1435 const bool sparse = true;
1436 auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
1437 ctx, use_exclusive_lock_, sparse, {0});
1438 Tensor var;
1439 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
1440 ctx, 0, use_exclusive_lock_, sparse, &var));
1441 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
1442 errors::InvalidArgument("var must be at least 1 dimensional"));
1443
1444 const Tensor& lr = ctx->input(1);
1445 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
1446 errors::InvalidArgument("lr is not a scalar: ",
1447 lr.shape().DebugString()));
1448 const Tensor& l1 = ctx->input(2);
1449 OP_REQUIRES(
1450 ctx, TensorShapeUtils::IsScalar(l1.shape()),
1451 errors::InvalidArgument("l1 regularization strength is not a scalar: ",
1452 l1.shape().DebugString()));
1453 const Tensor& l2 = ctx->input(3);
1454 OP_REQUIRES(
1455 ctx, TensorShapeUtils::IsScalar(l2.shape()),
1456 errors::InvalidArgument("l2 regularization strength is not a scalar: ",
1457 l2.shape().DebugString()));
1458
1459 const Tensor& grad = ctx->input(4);
1460 const Tensor& indices = ctx->input(5);
1461 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
1462 errors::InvalidArgument("indices must be one-dimensional"));
1463
1464 int64_t inner_dim = 1;
1465 for (int d = 1; d < var.dims(); d++) {
1466 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
1467 errors::InvalidArgument(strings::StrCat(
1468 "var and grad must match in dimension ", d)));
1469 inner_dim *= grad.dim_size(d);
1470 }
1471 const Tindex N = indices.dim_size(0);
1472 OP_REQUIRES(
1473 ctx, grad.dim_size(0) == N,
1474 errors::InvalidArgument(
1475 "grad must be the same size as indices in the first dimension."));
1476 OP_REQUIRES(ctx, inner_dim > 0,
1477 errors::InvalidArgument(
1478 "Inner dimension should be greater than zero."));
1479
1480 if (N > 0) {
1481 if (inner_dim > 1) {
1482 const Tindex first_dim_size = var.dim_size(0);
1483 auto indices_vec = indices.vec<Tindex>();
1484 auto var_flat = var.flat_outer_dims<T>();
1485 auto grad_flat = grad.flat_outer_dims<T>();
1486 T lr_scalar = lr.scalar<T>()();
1487 T l1_scalar = l1.scalar<T>()();
1488 T l2_scalar = l2.scalar<T>()();
1489
1490 // TODO(xbing): extract the common logic for the Fobos update.
1491 for (Tindex i = 0; i < N; i++) {
1492 const Tindex index = internal::SubtleMustCopy(indices_vec(i));
1493 OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
1494 errors::InvalidArgument(
1495 strings::StrCat("Index ", index, " at offset ", i,
1496 " in indices is out of range")));
1497 auto g = grad_flat.template chip<0>(i);
1498 auto v = var_flat.template chip<0>(index);
1499 // compute learning_rate for current step.
1500 auto learning_rate = v.constant(lr_scalar);
1501 auto prox_v = v;
1502 // v = w - g * learning_rate.
1503 prox_v -= g * learning_rate;
1504 if (l1_scalar > 0) {
1505 // compute sign(v) * max(|v| - learning_rate * l1, 0)
1506 v = prox_v.sign() *
1507 (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
1508 .cwiseMax(static_cast<T>(0.0)) /
1509 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
1510 } else {
1511 v = prox_v /
1512 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
1513 }
1514 }
1515 } else {
1516 auto indices_vec = indices.vec<Tindex>();
1517 auto var_flat = var.flat<T>();
1518 auto grad_flat = grad.flat<T>();
1519 T lr_scalar = lr.scalar<T>()();
1520 T l1_scalar = l1.scalar<T>()();
1521 T l2_scalar = l2.scalar<T>()();
1522 const Tindex first_dim_size = var_flat.size();
1523
1524 for (Tindex i = 0; i < N; i++) {
1525 const Tindex index = internal::SubtleMustCopy(indices_vec(i));
1526 OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
1527 errors::InvalidArgument(
1528 strings::StrCat("Index ", index, " at offset ", i,
1529 " in indices is out of range")));
1530 const T& g = grad_flat(i);
1531 auto learning_rate = lr_scalar;
1532 auto prox_v = var_flat(index);
1533 prox_v -= learning_rate * g;
1534 if (l1_scalar > 0) {
1535 var_flat(index) =
1536 sgn(prox_v) *
1537 std::max(std::abs(prox_v) - learning_rate * l1_scalar,
1538 static_cast<T>(0.0)) /
1539 (1.0 + l2_scalar * learning_rate);
1540 } else {
1541 var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);
1542 }
1543 }
1544 }
1545 }
1546
1547 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1548 }
1549
1550 private:
1551 bool use_exclusive_lock_;
1552};
1553
1554#define REGISTER_KERNELS(T, Tindices) \
1555 REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent") \
1556 .Device(DEVICE_CPU) \
1557 .TypeConstraint<T>("T") \
1558 .TypeConstraint<Tindices>("Tindices"), \
1559 SparseApplyProximalGradientDescentOp<T, Tindices>); \
1560 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalGradientDescent") \
1561 .Device(DEVICE_CPU) \
1562 .TypeConstraint<T>("T") \
1563 .TypeConstraint<Tindices>("Tindices"), \
1564 SparseApplyProximalGradientDescentOp<T, Tindices>);
1565
1566REGISTER_KERNELS(float, int32);
1567REGISTER_KERNELS(float, int64_t);
1568REGISTER_KERNELS(double, int32);
1569REGISTER_KERNELS(double, int64_t);
1570#undef REGISTER_KERNELS
1571
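// Dense Adagrad. The update itself lives in functor::ApplyAdagrad
// (training_ops.h); as a rough sketch:
//   accum += grad^2                     (skipped when update_slots is false)
//   var   -= lr * grad / sqrt(accum)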
1572template <typename Device, typename T>
1573class ApplyAdagradOp : public OpKernel {
1574 public:
1575 explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1576 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1577 OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
1578 }
1579
1580 void Compute(OpKernelContext* ctx) override {
1581 const bool sparse = false;
1582 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
1583 ctx, use_exclusive_lock_, sparse, {0, 1});
1584 Tensor var;
1585 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1586 ctx, 0, use_exclusive_lock_, sparse, &var));
1587 Tensor accum;
1588 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1589 ctx, 1, use_exclusive_lock_, sparse, &accum));
1590 OP_REQUIRES(
1591 ctx, var.IsInitialized(),
1592 errors::FailedPrecondition(
1593 "Attempting to use uninitialized variables: ", requested_input(0)));
1594 OP_REQUIRES(
1595 ctx, accum.IsInitialized(),
1596 errors::FailedPrecondition(
1597 "Attempting to use uninitialized variables: ", requested_input(1)));
1598 const Tensor& lr = ctx->input(2);
1599 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
1600 errors::InvalidArgument("lr is not a scalar: ",
1601 lr.shape().DebugString()));
1602 const Tensor& grad = ctx->input(3);
1603 OP_REQUIRES(
1604 ctx, var.shape().IsSameSize(accum.shape()),
1605 errors::InvalidArgument("var and accum do not have the same shape",
1606 var.shape().DebugString(), " ",
1607 accum.shape().DebugString()));
1608 OP_REQUIRES(
1609 ctx, var.shape().IsSameSize(grad.shape()),
1610 errors::InvalidArgument("var and grad do not have the same shape",
1611 var.shape().DebugString(), " ",
1612 grad.shape().DebugString()));
1613
1614 const Device& device = ctx->template eigen_device<Device>();
1615 functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
1616 lr.scalar<T>(), grad.flat<T>(),
1617 update_slots_);
1618
1619 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1620 }
1621
1622 private:
1623 bool use_exclusive_lock_;
1624 bool update_slots_;
1625};
1626
1627#define REGISTER_KERNELS(D, T) \
1628 REGISTER_KERNEL_BUILDER( \
1629 Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
1630 ApplyAdagradOp<D##Device, T>); \
1631 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagrad") \
1632 .HostMemory("var") \
1633 .HostMemory("accum") \
1634 .Device(DEVICE_##D) \
1635 .TypeConstraint<T>("T"), \
1636 ApplyAdagradOp<D##Device, T>);
1637#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
1638
1639TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
1640TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
1641
1642#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1643// Forward declarations of the functor specializations for GPU.
1644namespace functor {
1645#define DECLARE_GPU_SPEC(T) \
1646 template <> \
1647 void ApplyAdagrad<GPUDevice, T>::operator()( \
1648 const GPUDevice& d, typename TTypes<T>::Flat var, \
1649 typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
1650 typename TTypes<T>::ConstFlat grad, bool update_slots); \
1651 extern template struct ApplyAdagrad<GPUDevice, T>;
1652DECLARE_GPU_SPEC(Eigen::half);
1653DECLARE_GPU_SPEC(float);
1654DECLARE_GPU_SPEC(double);
1655DECLARE_GPU_SPEC(complex64);
1656DECLARE_GPU_SPEC(complex128);
1657#undef DECLARE_GPU_SPEC
1658} // namespace functor
1659
1660REGISTER_KERNELS(GPU, Eigen::half);
1661REGISTER_KERNELS(GPU, float);
1662REGISTER_KERNELS(GPU, double);
1663REGISTER_KERNELS(GPU, complex64);
1664REGISTER_KERNELS(GPU, complex128);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1666#undef REGISTER_CPU_KERNELS
1667#undef REGISTER_KERNELS
1668
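// Dense AdagradV2: Adagrad with an explicit epsilon term in the denominator.
// See functor::ApplyAdagradV2 (training_ops.h); as a rough sketch:
//   accum += grad^2                     (skipped when update_slots is false)
//   var   -= lr * grad / (sqrt(accum) + epsilon)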
1669template <typename Device, typename T>
1670class ApplyAdagradV2Op : public OpKernel {
1671 public:
1672 explicit ApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
1673 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1674 OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
1675 }
1676
1677 void Compute(OpKernelContext* ctx) override {
1678 const bool sparse = false;
1679 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
1680 ctx, use_exclusive_lock_, sparse, {0, 1});
1681 Tensor var;
1682 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1683 ctx, 0, use_exclusive_lock_, sparse, &var));
1684 Tensor accum;
1685 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1686 ctx, 1, use_exclusive_lock_, sparse, &accum));
1687 OP_REQUIRES(
1688 ctx, var.IsInitialized(),
1689 errors::FailedPrecondition(
1690 "Attempting to use uninitialized variables: ", requested_input(0)));
1691 OP_REQUIRES(
1692 ctx, accum.IsInitialized(),
1693 errors::FailedPrecondition(
1694 "Attempting to use uninitialized variables: ", requested_input(1)));
1695 const Tensor& lr = ctx->input(2);
1696 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
1697 errors::InvalidArgument("lr is not a scalar: ",
1698 lr.shape().DebugString()));
1699 const Tensor& epsilon = ctx->input(3);
1700 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
1701 errors::InvalidArgument("epsilon is not a scalar: ",
1702 epsilon.shape().DebugString()));
1703 const Tensor& grad = ctx->input(4);
1704 OP_REQUIRES(
1705 ctx, var.shape().IsSameSize(accum.shape()),
1706 errors::InvalidArgument("var and accum do not have the same shape",
1707 var.shape().DebugString(), " ",
1708 accum.shape().DebugString()));
1709 OP_REQUIRES(
1710 ctx, var.shape().IsSameSize(grad.shape()),
1711 errors::InvalidArgument("var and grad do not have the same shape",
1712 var.shape().DebugString(), " ",
1713 grad.shape().DebugString()));
1714
1715 const Device& device = ctx->template eigen_device<Device>();
1716 functor::ApplyAdagradV2<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
1717 lr.scalar<T>(), epsilon.scalar<T>(),
1718 grad.flat<T>(), update_slots_);
1719
1720 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1721 }
1722
1723 private:
1724 bool use_exclusive_lock_;
1725 bool update_slots_;
1726};
1727
1728#define REGISTER_KERNELS(D, T) \
1729 REGISTER_KERNEL_BUILDER( \
1730 Name("ApplyAdagradV2").Device(DEVICE_##D).TypeConstraint<T>("T"), \
1731 ApplyAdagradV2Op<D##Device, T>); \
1732 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradV2") \
1733 .HostMemory("var") \
1734 .HostMemory("accum") \
1735 .Device(DEVICE_##D) \
1736 .TypeConstraint<T>("T"), \
1737 ApplyAdagradV2Op<D##Device, T>);
1738#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
1739
1740TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
1741TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
1742
1743#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1744// Forward declarations of the functor specializations for GPU.
1745namespace functor {
1746#define DECLARE_GPU_SPEC(T) \
1747 template <> \
1748 void ApplyAdagradV2<GPUDevice, T>::operator()( \
1749 const GPUDevice& d, typename TTypes<T>::Flat var, \
1750 typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
1751 typename TTypes<T>::ConstScalar epsilon, \
1752 typename TTypes<T>::ConstFlat grad, bool update_slots); \
1753 extern template struct ApplyAdagradV2<GPUDevice, T>;
1754DECLARE_GPU_SPEC(Eigen::half);
1755DECLARE_GPU_SPEC(float);
1756DECLARE_GPU_SPEC(double);
1757DECLARE_GPU_SPEC(complex64);
1758DECLARE_GPU_SPEC(complex128);
1759#undef DECLARE_GPU_SPEC
1760} // namespace functor
1761
1762REGISTER_KERNELS(GPU, Eigen::half);
1763REGISTER_KERNELS(GPU, float);
1764REGISTER_KERNELS(GPU, double);
1765REGISTER_KERNELS(GPU, complex64);
1766REGISTER_KERNELS(GPU, complex128);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1768#undef REGISTER_CPU_KERNELS
1769#undef REGISTER_KERNELS
1770
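// Dense proximal Adagrad (the Fobos update with Adagrad learning rates). See
// functor::ApplyProximalAdagrad (training_ops.h); as a rough sketch:
//   accum  += grad^2
//   lr_t    = lr / sqrt(accum)
//   prox_v  = var - lr_t * grad
//   var     = sign(prox_v) * max(|prox_v| - lr_t * l1, 0) / (1 + lr_t * l2)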
1771template <typename Device, typename T>
1772class ApplyProximalAdagradOp : public OpKernel {
1773 public:
1774 explicit ApplyProximalAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1775 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1776 }
1777
1778 void Compute(OpKernelContext* ctx) override {
1779 const bool sparse = false;
1780 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
1781 ctx, use_exclusive_lock_, sparse, {0, 1});
1782 Tensor var;
1783 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1784 ctx, 0, use_exclusive_lock_, sparse, &var));
1785 Tensor accum;
1786 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1787 ctx, 1, use_exclusive_lock_, sparse, &accum));
1788 OP_REQUIRES(
1789 ctx, var.IsInitialized(),
1790 errors::FailedPrecondition(
1791 "Attempting to use uninitialized variables: ", requested_input(0)));
1792 OP_REQUIRES(
1793 ctx, accum.IsInitialized(),
1794 errors::FailedPrecondition(
1795 "Attempting to use uninitialized variables: ", requested_input(1)));
1796 OP_REQUIRES(
1797 ctx, var.shape().IsSameSize(accum.shape()),
1798 errors::InvalidArgument("var and accum do not have the same shape",
1799 var.shape().DebugString(), " ",
1800 accum.shape().DebugString()));
1801 const Tensor& lr = ctx->input(2);
1802 OP_REQUIRES(ctx,
1803 TensorShapeUtils::IsScalar(lr.shape()) &&
1804 (!std::is_same<Device, CPUDevice>::value ||
1805 lr.scalar<T>()() > static_cast<T>(0)),
1806 errors::InvalidArgument("lr is not a positive scalar: ",
1807 lr.shape().DebugString()));
1808 const Tensor& l1 = ctx->input(3);
1809 OP_REQUIRES(ctx,
1810 TensorShapeUtils::IsScalar(l1.shape()) &&
1811 (!std::is_same<Device, CPUDevice>::value ||
1812 l1.scalar<T>()() >= static_cast<T>(0)),
1813 errors::InvalidArgument("l1 regularization strength is not a "
1814 "non-negative scalar: ",
1815 l1.shape().DebugString()));
1816 const Tensor& l2 = ctx->input(4);
1817 OP_REQUIRES(ctx,
1818 TensorShapeUtils::IsScalar(l2.shape()) &&
1819 (!std::is_same<Device, CPUDevice>::value ||
1820 l2.scalar<T>()() >= static_cast<T>(0)),
1821 errors::InvalidArgument("l2 regularization strength is not a "
1822 "non-negative scalar: ",
1823 l2.shape().DebugString()));
1824 const Tensor& grad = ctx->input(5);
1825 OP_REQUIRES(
1826 ctx, var.shape().IsSameSize(grad.shape()),
1827 errors::InvalidArgument("var and grad do not have the same shape",
1828 var.shape().DebugString(), " ",
1829 grad.shape().DebugString()));
1830
1831 const Device& device = ctx->template eigen_device<Device>();
1832 functor::ApplyProximalAdagrad<Device, T>()(
1833 device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), l1.scalar<T>(),
1834 l2.scalar<T>(), grad.flat<T>());
1835
1836 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1837 }
1838
1839 private:
1840 bool use_exclusive_lock_;
1841};
1842
1843#define REGISTER_KERNELS(D, T) \
1844 REGISTER_KERNEL_BUILDER( \
1845 Name("ApplyProximalAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
1846 ApplyProximalAdagradOp<D##Device, T>); \
1847 REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalAdagrad") \
1848 .Device(DEVICE_##D) \
1849 .HostMemory("var") \
1850 .HostMemory("accum") \
1851 .TypeConstraint<T>("T"), \
1852 ApplyProximalAdagradOp<D##Device, T>);
1853
1854REGISTER_KERNELS(CPU, float);
1855REGISTER_KERNELS(CPU, double);
1856
1857#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1858// Forward declarations of the functor specializations for GPU.
1859namespace functor {
1860#define DECLARE_GPU_SPEC(T) \
1861 template <> \
1862 void ApplyProximalAdagrad<GPUDevice, T>::operator()( \
1863 const GPUDevice& d, typename TTypes<T>::Flat var, \
1864 typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
1865 typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
1866 typename TTypes<T>::ConstFlat grad); \
1867 extern template struct ApplyProximalAdagrad<GPUDevice, T>;
1868DECLARE_GPU_SPEC(Eigen::half);
1869DECLARE_GPU_SPEC(float);
1870DECLARE_GPU_SPEC(double);
1871#undef DECLARE_GPU_SPEC
1872} // namespace functor
1873
1874REGISTER_KERNELS(GPU, Eigen::half);
1875REGISTER_KERNELS(GPU, float);
1876REGISTER_KERNELS(GPU, double);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1878#undef REGISTER_CPU_KERNELS
1879#undef REGISTER_KERNELS
1880
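// Sparse Adagrad: the dense Adagrad update applied only to the rows named in
// `indices`, dispatched to functor::SparseApplyAdagrad with
// /*has_epsilon=*/false (lr is passed again as a dummy epsilon).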
1881template <typename Device, typename T, typename Tindex>
1882class SparseApplyAdagradOp : public OpKernel {
1883 public:
1884 explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1885 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1886 OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
1887 }
1888
1889 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
1890 const bool sparse = true;
1891 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
1892 ctx, use_exclusive_lock_, sparse, {0, 1});
1893 Tensor var;
1894 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1895 ctx, 0, use_exclusive_lock_, sparse, &var));
1896 Tensor accum;
1897 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
1898 ctx, 1, use_exclusive_lock_, sparse, &accum));
1899 OP_REQUIRES(
1900 ctx, var.IsInitialized(),
1901 errors::FailedPrecondition(
1902 "Attempting to use uninitialized variables: ", requested_input(0)));
1903 OP_REQUIRES(
1904 ctx, accum.IsInitialized(),
1905 errors::FailedPrecondition(
1906 "Attempting to use uninitialized variables: ", requested_input(1)));
1907 OP_REQUIRES(
1908 ctx, var.shape().IsSameSize(accum.shape()),
1909 errors::InvalidArgument("var and accum do not have the same shape",
1910 var.shape().DebugString(), " ",
1911 accum.shape().DebugString()));
1912 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
1913 errors::InvalidArgument("var must be at least 1 dimensional"));
1914
1915 const Tensor& lr = ctx->input(2);
1916 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
1917 errors::InvalidArgument("lr is not a scalar: ",
1918 lr.shape().DebugString()));
1919 const Tensor& grad = ctx->input(3);
1920 const Tensor& indices = ctx->input(4);
1921 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
1922 errors::InvalidArgument("indices must be one-dimensional"));
1923
1924 int64_t inner_dim = 1;
1925 for (int d = 1; d < var.dims(); d++) {
1926 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
1927 errors::InvalidArgument(strings::StrCat(
1928 "var and grad must match in dimension ", d)));
1929 inner_dim *= grad.dim_size(d);
1930 }
1931 const Tindex N = indices.dim_size(0);
1932 OP_REQUIRES(
1933 ctx, grad.dim_size(0) == N,
1934 errors::InvalidArgument(
1935 "grad must be the same size as indices in the first dimension."));
1936
1937 OP_REQUIRES(ctx, inner_dim > 0,
1938 errors::InvalidArgument(
1939 "Inner dimension should be greater than zero."));
1940
1941 const Device& device = ctx->template eigen_device<Device>();
1942 OP_REQUIRES_OK(
1943 ctx, functor::SparseApplyAdagrad<Device, T, Tindex,
1944 /*has_epsilon = */ false>()(
1945 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
1946 // Note: Passing lr as a placeholder for unused epsilon.
1947 lr.scalar<T>(), lr.scalar<T>(), grad.flat_outer_dims<T>(),
1948 indices.vec<Tindex>(), inner_dim, update_slots_));
1949
1950 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
1951 }
1952
1953 private:
1954 bool use_exclusive_lock_;
1955 bool update_slots_;
1956};
1957
1958#define REGISTER_KERNELS(D, T, Tindices) \
1959 REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad") \
1960 .Device(DEVICE_##D) \
1961 .TypeConstraint<T>("T") \
1962 .TypeConstraint<Tindices>("Tindices"), \
1963 SparseApplyAdagradOp<D##Device, T, Tindices>); \
1964 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad") \
1965 .Device(DEVICE_##D) \
1966 .TypeConstraint<T>("T") \
1967 .TypeConstraint<Tindices>("Tindices"), \
1968 SparseApplyAdagradOp<D##Device, T, Tindices>);
1969#define REGISTER_CPU_KERNELS(T) \
1970 REGISTER_KERNELS(CPU, T, int32); \
1971 REGISTER_KERNELS(CPU, T, int64_t);
1972
1973TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
1974TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
1975
1976#undef REGISTER_CPU_KERNELS
1977
1978#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1979// Forward declarations of the functor specializations for GPU.
1980namespace functor {
1981#define DECLARE_GPU_SPEC(T, Tindex) \
1982 template <> \
1983 Status \
1984 SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/false>::operator()( \
1985 const GPUDevice& d, typename TTypes<T>::Matrix var, \
1986 typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
1987 typename TTypes<T>::ConstScalar epsilon, \
1988 typename TTypes<T>::ConstMatrix grad, \
1989 typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \
1990 bool update_slots); \
1991 extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \
1992 /*has_epsilon=*/false>;
1993DECLARE_GPU_SPEC(Eigen::half, int32);
1994DECLARE_GPU_SPEC(Eigen::half, int64_t);
1995DECLARE_GPU_SPEC(float, int32);
1996DECLARE_GPU_SPEC(float, int64_t);
1997DECLARE_GPU_SPEC(double, int32);
1998DECLARE_GPU_SPEC(double, int64_t);
1999#undef DECLARE_GPU_SPEC
2000} // namespace functor
2001
2002REGISTER_KERNELS(GPU, Eigen::half, int32);
2003REGISTER_KERNELS(GPU, Eigen::half, int64_t);
2004REGISTER_KERNELS(GPU, float, int32);
2005REGISTER_KERNELS(GPU, float, int64_t);
2006REGISTER_KERNELS(GPU, double, int32);
2007REGISTER_KERNELS(GPU, double, int64_t);
2008#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2009#undef REGISTER_KERNELS
2010
2011template <typename Device, typename T, typename Tindex>
2012class SparseApplyAdagradV2Op : public OpKernel {
2013 public:
2014 explicit SparseApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
2015 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
2016 OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
2017 }
2018
2019 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
2020 const bool sparse = true;
2021 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
2022 ctx, use_exclusive_lock_, sparse, {0, 1});
2023 Tensor var;
2024 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2025 ctx, 0, use_exclusive_lock_, sparse, &var));
2026 Tensor accum;
2027 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2028 ctx, 1, use_exclusive_lock_, sparse, &accum));
2029 OP_REQUIRES(
2030 ctx, var.IsInitialized(),
2031 errors::FailedPrecondition(
2032 "Attempting to use uninitialized variables: ", requested_input(0)));
2033 OP_REQUIRES(
2034 ctx, accum.IsInitialized(),
2035 errors::FailedPrecondition(
2036 "Attempting to use uninitialized variables: ", requested_input(1)));
2037 OP_REQUIRES(
2038 ctx, var.shape().IsSameSize(accum.shape()),
2039 errors::InvalidArgument("var and accum do not have the same shape",
2040 var.shape().DebugString(), " ",
2041 accum.shape().DebugString()));
2042 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
2043 errors::InvalidArgument("var must be at least 1 dimensional"));
2044
2045 const Tensor& lr = ctx->input(2);
2046 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
2047 errors::InvalidArgument("lr is not a scalar: ",
2048 lr.shape().DebugString()));
2049 const Tensor& epsilon = ctx->input(3);
2050 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
2051 errors::InvalidArgument("epsilon is not a scalar: ",
2052 epsilon.shape().DebugString()));
2053 const Tensor& grad = ctx->input(4);
2054 const Tensor& indices = ctx->input(5);
2055 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
2056 errors::InvalidArgument("indices must be one-dimensional"));
2057
2058 int64_t inner_dim = 1;
2059 for (int d = 1; d < var.dims(); d++) {
2060 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
2061 errors::InvalidArgument(strings::StrCat(
2062 "var and grad must match in dimension ", d)));
2063 inner_dim *= grad.dim_size(d);
2064 }
2065 const Tindex N = indices.dim_size(0);
2066 OP_REQUIRES(
2067 ctx, grad.dim_size(0) == N,
2068 errors::InvalidArgument(
2069 "grad must be the same size as indices in the first dimension."));
2070
2071 OP_REQUIRES(ctx, inner_dim > 0,
2072 errors::InvalidArgument(
2073 "Inner dimension should be greater than zero."));
2074
2075 const Device& device = ctx->template eigen_device<Device>();
2076 OP_REQUIRES_OK(
2077 ctx, functor::SparseApplyAdagrad<Device, T, Tindex,
2078 /*has_epsilon = */ true>()(
2079 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
2080 lr.scalar<T>(), epsilon.scalar<T>(), grad.flat_outer_dims<T>(),
2081 indices.vec<Tindex>(), inner_dim, update_slots_));
2082
2083 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
2084 }
2085
2086 private:
2087 bool use_exclusive_lock_;
2088 bool update_slots_;
2089};
2090
2091#define REGISTER_KERNELS(D, T, Tindices) \
2092 REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradV2") \
2093 .Device(DEVICE_##D) \
2094 .TypeConstraint<T>("T") \
2095 .TypeConstraint<Tindices>("Tindices"), \
2096 SparseApplyAdagradV2Op<D##Device, T, Tindices>); \
2097 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradV2") \
2098 .Device(DEVICE_##D) \
2099 .TypeConstraint<T>("T") \
2100 .TypeConstraint<Tindices>("Tindices"), \
2101 SparseApplyAdagradV2Op<D##Device, T, Tindices>);
2102#define REGISTER_CPU_KERNELS(T) \
2103 REGISTER_KERNELS(CPU, T, int32); \
2104 REGISTER_KERNELS(CPU, T, int64_t);
2105
2106TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
2107TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
2108
2109#undef REGISTER_CPU_KERNELS
2110
2111#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2112// Forward declarations of the functor specializations for GPU.
2113namespace functor {
2114#define DECLARE_GPU_SPEC(T, Tindex) \
2115 template <> \
2116 Status \
2117 SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/true>::operator()( \
2118 const GPUDevice& d, typename TTypes<T>::Matrix var, \
2119 typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
2120 typename TTypes<T>::ConstScalar epsilon, \
2121 typename TTypes<T>::ConstMatrix grad, \
2122 typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \
2123 bool update_slots); \
2124 extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \
2125 /*has_epsilon=*/true>;
2126DECLARE_GPU_SPEC(Eigen::half, int32);
2127DECLARE_GPU_SPEC(Eigen::half, int64_t);
2128DECLARE_GPU_SPEC(float, int32);
2129DECLARE_GPU_SPEC(float, int64_t);
2130DECLARE_GPU_SPEC(double, int32);
2131DECLARE_GPU_SPEC(double, int64_t);
2132#undef DECLARE_GPU_SPEC
2133} // namespace functor
2134
2135REGISTER_KERNELS(GPU, Eigen::half, int32);
2136REGISTER_KERNELS(GPU, Eigen::half, int64_t);
2137REGISTER_KERNELS(GPU, float, int32);
2138REGISTER_KERNELS(GPU, float, int64_t);
2139REGISTER_KERNELS(GPU, double, int32);
2140REGISTER_KERNELS(GPU, double, int64_t);
2141#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2142#undef REGISTER_KERNELS
2143
2144template <typename Device, typename T, typename Tindex>
2145class SparseApplyProximalAdagradOp : public OpKernel {
2146 public:
2147 explicit SparseApplyProximalAdagradOp(OpKernelConstruction* ctx)
2148 : OpKernel(ctx) {
2149 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
2150 }
2151
2152 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
2153 const bool sparse = true;
2154 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
2155 ctx, use_exclusive_lock_, sparse, {0, 1});
2156 Tensor var;
2157 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2158 ctx, 0, use_exclusive_lock_, sparse, &var));
2159 Tensor accum;
2160 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2161 ctx, 1, use_exclusive_lock_, sparse, &accum));
2162 OP_REQUIRES(
2163 ctx, var.IsInitialized(),
2164 errors::FailedPrecondition(
2165 "Attempting to use uninitialized variables: ", requested_input(0)));
2166 OP_REQUIRES(
2167 ctx, accum.IsInitialized(),
2168 errors::FailedPrecondition(
2169 "Attempting to use uninitialized variables: ", requested_input(1)));
2170 OP_REQUIRES(
2171 ctx, var.shape().IsSameSize(accum.shape()),
2172 errors::InvalidArgument("var and accum do not have the same shape",
2173 var.shape().DebugString(), " ",
2174 accum.shape().DebugString()));
2175 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
2176 errors::InvalidArgument("var must be at least 1 dimensional"));
2177
2178 const Tensor& lr = ctx->input(2);
2179 OP_REQUIRES(ctx,
2180 TensorShapeUtils::IsScalar(lr.shape()) &&
2181 (!std::is_same<Device, CPUDevice>::value ||
2182 lr.scalar<T>()() > static_cast<T>(0)),
2183 errors::InvalidArgument("lr is not a positive scalar: ",
2184 lr.shape().DebugString()));
2185 const Tensor& l1 = ctx->input(3);
2186 OP_REQUIRES(ctx,
2187 TensorShapeUtils::IsScalar(l1.shape()) &&
2188 (!std::is_same<Device, CPUDevice>::value ||
2189 l1.scalar<T>()() >= static_cast<T>(0)),
2190 errors::InvalidArgument("l1 regularization strength is not a "
2191 "non-negative scalar: ",
2192 l1.shape().DebugString()));
2193 const Tensor& l2 = ctx->input(4);
2194 OP_REQUIRES(ctx,
2195 TensorShapeUtils::IsScalar(l2.shape()) &&
2196 (!std::is_same<Device, CPUDevice>::value ||
2197 l2.scalar<T>()() >= static_cast<T>(0)),
2198 errors::InvalidArgument("l2 regularization strength is not a "
2199 "non-negative scalar: ",
2200 l2.shape().DebugString()));
2201
2202 const Tensor& grad = ctx->input(5);
2203 const Tensor& indices = ctx->input(6);
2204 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
2205 errors::InvalidArgument("indices must be one-dimensional"));
2206
2207 int64_t inner_dim = 1;
2208 for (int d = 1; d < var.dims(); d++) {
2209 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
2210 errors::InvalidArgument(strings::StrCat(
2211 "var and grad must match in dimension ", d)));
2212 inner_dim *= grad.dim_size(d);
2213 }
2214 const Tindex N = indices.dim_size(0);
2215 OP_REQUIRES(
2216 ctx, grad.dim_size(0) == N,
2217 errors::InvalidArgument(
2218 "grad must be the same size as indices in the first dimension."));
2219
2220 OP_REQUIRES(ctx, inner_dim > 0,
2221 errors::InvalidArgument(
2222 "Inner dimension should be greater than zero."));
2223
2224 const Device& device = ctx->template eigen_device<Device>();
2225 OP_REQUIRES_OK(
2226 ctx, functor::SparseApplyProximalAdagrad<Device, T, Tindex>()(
2227 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
2228 lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(),
2229 grad.flat_outer_dims<T>(), indices.vec<Tindex>(), inner_dim));
2230
2231 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
2232 }
2233
2234 private:
2235 bool use_exclusive_lock_;
2236};
2237
2238#define REGISTER_KERNELS(D, T, Tindices) \
2239 REGISTER_KERNEL_BUILDER( \
2240 Name("SparseApplyProximalAdagrad") \
2241 .Device(DEVICE_##D) \
2242 .TypeConstraint<T>("T") \
2243 .TypeConstraint<Tindices>("Tindices"), \
2244 SparseApplyProximalAdagradOp<D##Device, T, Tindices>); \
2245 REGISTER_KERNEL_BUILDER( \
2246 Name("ResourceSparseApplyProximalAdagrad") \
2247 .Device(DEVICE_##D) \
2248 .TypeConstraint<T>("T") \
2249 .TypeConstraint<Tindices>("Tindices"), \
2250 SparseApplyProximalAdagradOp<D##Device, T, Tindices>);
2251
2252REGISTER_KERNELS(CPU, float, int32);
2253REGISTER_KERNELS(CPU, float, int64_t);
2254REGISTER_KERNELS(CPU, double, int32);
2255REGISTER_KERNELS(CPU, double, int64_t);
2256
2257#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2258// Forward declarations of the functor specializations for GPU.
2259namespace functor {
2260#define DECLARE_GPU_SPEC(T, Tindex) \
2261 template <> \
2262 Status SparseApplyProximalAdagrad<GPUDevice, T, Tindex>::operator()( \
2263 const GPUDevice& d, typename TTypes<T>::Matrix var, \
2264 typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
2265 typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
2266 typename TTypes<T>::ConstMatrix grad, \
2267 typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim); \
2268 extern template struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex>;
2269DECLARE_GPU_SPEC(Eigen::half, int32);
2270DECLARE_GPU_SPEC(Eigen::half, int64_t);
2271DECLARE_GPU_SPEC(float, int32);
2272DECLARE_GPU_SPEC(float, int64_t);
2273DECLARE_GPU_SPEC(double, int32);
2274DECLARE_GPU_SPEC(double, int64_t);
2275#undef DECLARE_GPU_SPEC
2276} // namespace functor
2277
2278REGISTER_KERNELS(GPU, Eigen::half, int32);
2279REGISTER_KERNELS(GPU, Eigen::half, int64_t);
2280REGISTER_KERNELS(GPU, float, int32);
2281REGISTER_KERNELS(GPU, float, int64_t);
2282REGISTER_KERNELS(GPU, double, int32);
2283REGISTER_KERNELS(GPU, double, int64_t);
2284#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2285#undef REGISTER_KERNELS
2286
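// Dense AdagradDA (Adagrad with dual averaging). The update is spelled out in
// the comment inside the sparse kernel below; the dense math lives in
// functor::ApplyAdagradDA (training_ops.h).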
2287template <typename Device, typename T>
2288class ApplyAdagradDAOp : public OpKernel {
2289 public:
2290 explicit ApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
2291 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
2292 }
2293
2294 void Compute(OpKernelContext* ctx) override {
2295 const bool sparse = false;
2296 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
2297 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
2298 Tensor var;
2299 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2300 ctx, 0, use_exclusive_lock_, sparse, &var));
2301 Tensor gradient_accum;
2302 OP_REQUIRES_OK(
2303 ctx, GetInputTensorFromVariable<Device, T>(ctx, 1, use_exclusive_lock_,
2304 sparse, &gradient_accum));
2305 Tensor gradient_squared_accum;
2306 OP_REQUIRES_OK(
2307 ctx, GetInputTensorFromVariable<Device, T>(
2308 ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum));
2309 OP_REQUIRES(
2310 ctx, var.IsInitialized(),
2311 errors::FailedPrecondition(
2312 "Attempting to use uninitialized variables: ", requested_input(0)));
2313 OP_REQUIRES(
2314 ctx, gradient_accum.IsInitialized(),
2315 errors::FailedPrecondition(
2316 "Attempting to use uninitialized variables: ", requested_input(1)));
2317 OP_REQUIRES(
2318 ctx, gradient_squared_accum.IsInitialized(),
2319 errors::FailedPrecondition(
2320 "Attempting to use uninitialized variables: ", requested_input(2)));
2321 OP_REQUIRES(
2322 ctx, var.shape().IsSameSize(gradient_accum.shape()),
2323 errors::InvalidArgument("var and accum do not have the same shape",
2324 var.shape().DebugString(), " ",
2325 gradient_accum.shape().DebugString()));
2326 OP_REQUIRES(
2327 ctx, var.shape().IsSameSize(gradient_squared_accum.shape()),
2328 errors::InvalidArgument("var and accum do not have the same shape",
2329 var.shape().DebugString(), " ",
2330 gradient_squared_accum.shape().DebugString()));
2331
2332 const Tensor& grad = ctx->input(3);
2333 OP_REQUIRES(
2334 ctx, var.shape().IsSameSize(grad.shape()),
2335 errors::InvalidArgument("var and grad do not have the same shape",
2336 var.shape().DebugString(), " ",
2337 grad.shape().DebugString()));
2338
2339 const Tensor& lr = ctx->input(4);
2340 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
2341 errors::InvalidArgument("lr is not a scalar: ",
2342 lr.shape().DebugString()));
2343 const Tensor& l1 = ctx->input(5);
2344 OP_REQUIRES(
2345 ctx, TensorShapeUtils::IsScalar(l1.shape()),
2346 errors::InvalidArgument("l1 regularization strength is not a scalar: ",
2347 l1.shape().DebugString()));
2348 const Tensor& l2 = ctx->input(6);
2349 OP_REQUIRES(
2350 ctx, TensorShapeUtils::IsScalar(l2.shape()),
2351 errors::InvalidArgument("l2 regularization strength is not a scalar: ",
2352 l2.shape().DebugString()));
2353 const Tensor& global_step = ctx->input(7);
2354 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
2355 errors::InvalidArgument("global_step is not a scalar: ",
2356 global_step.shape().DebugString()));
2357
2358 const Device& device = ctx->template eigen_device<Device>();
2359 functor::ApplyAdagradDA<Device, T>()(
2360 device, var.flat<T>(), gradient_accum.flat<T>(),
2361 gradient_squared_accum.flat<T>(), lr.scalar<T>(),
2362 global_step.scalar<int64_t>()(), l1.scalar<T>(), l2.scalar<T>(),
2363 grad.flat<T>());
2364
2365 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
2366 }
2367
2368 private:
2369 bool use_exclusive_lock_;
2370};
2371
2372#define REGISTER_KERNELS(D, T) \
2373 REGISTER_KERNEL_BUILDER( \
2374 Name("ApplyAdagradDA").Device(DEVICE_##D).TypeConstraint<T>("T"), \
2375 ApplyAdagradDAOp<D##Device, T>); \
2376 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradDA") \
2377 .Device(DEVICE_##D) \
2378 .HostMemory("var") \
2379 .HostMemory("gradient_accumulator") \
2380 .HostMemory("gradient_squared_accumulator") \
2381 .TypeConstraint<T>("T"), \
2382 ApplyAdagradDAOp<D##Device, T>);
2383
2384REGISTER_KERNELS(CPU, float);
2385REGISTER_KERNELS(CPU, double);
2386#undef REGISTER_KERNELS
2387
// Note: this op works on the CPU only.
2389template <typename T, typename Tindex>
2390class SparseApplyAdagradDAOp : public OpKernel {
2391 public:
2392 explicit SparseApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
2393 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
2394 }
2395
2396 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
2397 const bool sparse = true;
2398 auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
2399 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
2400 Tensor var;
2401 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
2402 ctx, 0, use_exclusive_lock_, sparse, &var));
2403 Tensor gradient_accum;
2404 OP_REQUIRES_OK(ctx,
2405 GetInputTensorFromVariable<CPUDevice, T>(
2406 ctx, 1, use_exclusive_lock_, sparse, &gradient_accum));
2407 Tensor gradient_squared_accum;
2408 OP_REQUIRES_OK(
2409 ctx, GetInputTensorFromVariable<CPUDevice, T>(
2410 ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum));
2411 OP_REQUIRES(
2412 ctx, var.IsInitialized(),
2413 errors::FailedPrecondition(
2414 "Attempting to use uninitialized variables: ", requested_input(0)));
2415 OP_REQUIRES(
2416 ctx, gradient_accum.IsInitialized(),
2417 errors::FailedPrecondition(
2418 "Attempting to use uninitialized variables: ", requested_input(1)));
2419 OP_REQUIRES(
2420 ctx, gradient_squared_accum.IsInitialized(),
2421 errors::FailedPrecondition(
2422 "Attempting to use uninitialized variables: ", requested_input(2)));
2423 OP_REQUIRES(
2424 ctx, var.shape().IsSameSize(gradient_accum.shape()),
2425 errors::InvalidArgument("var and accum do not have the same shape",
2426 var.shape().DebugString(), " ",
2427 gradient_accum.shape().DebugString()));
2428 OP_REQUIRES(
2429 ctx, var.shape().IsSameSize(gradient_squared_accum.shape()),
2430 errors::InvalidArgument("var and accum do not have the same shape",
2431 var.shape().DebugString(), " ",
2432 gradient_squared_accum.shape().DebugString()));
2433
2434 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
2435 errors::InvalidArgument("var must be at least 1 dimensional"));
2436
2437 const Tensor& grad = ctx->input(3);
2438 const Tensor& indices = ctx->input(4);
2439 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
2440 errors::InvalidArgument("indices must be one-dimensional"));
2441
2442 const Tensor& lr = ctx->input(5);
2443 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
2444 errors::InvalidArgument("lr is not a scalar: ",
2445 lr.shape().DebugString()));
2446
2447 const Tensor& l1 = ctx->input(6);
2448 OP_REQUIRES(
2449 ctx, TensorShapeUtils::IsScalar(l1.shape()),
2450 errors::InvalidArgument("l1 regularization strength is not a scalar: ",
2451 l1.shape().DebugString()));
2452
2453 const Tensor& l2 = ctx->input(7);
2454 OP_REQUIRES(
2455 ctx, TensorShapeUtils::IsScalar(l2.shape()),
2456 errors::InvalidArgument("l2 regularization strength is not a scalar: ",
2457 l2.shape().DebugString()));
2458
2459 const Tensor& global_step = ctx->input(8);
2460 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
2461 errors::InvalidArgument("global_step is not a scalar: ",
2462 global_step.shape().DebugString()));
2463
2464 int64_t inner_dim = 1;
2465 for (int d = 1; d < var.dims(); d++) {
2466 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
2467 errors::InvalidArgument(strings::StrCat(
2468 "var and grad must match in dimension ", d)));
2469 inner_dim *= grad.dim_size(d);
2470 }
2471 const Tindex N = indices.dim_size(0);
2472 OP_REQUIRES(
2473 ctx, grad.dim_size(0) == N,
2474 errors::InvalidArgument(
2475 "grad must be the same size as indices in the first dimension."));
2476
2477 OP_REQUIRES(ctx, inner_dim > 0,
2478 errors::InvalidArgument(
2479 "Inner dimension should be greater than zero."));
2480
    // AdagradDA update:
    // Let g be the gradient accumulator, gg the gradient squared accumulator,
    // T the global step, lr the learning rate, and k the initial gradient
    // squared accumulator value.
    // w = sign(-g) * lr * (|g| - l1 * T)_{+} / (l2 * T * lr + sqrt(k + gg))
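    // When l1 <= 0 the (|g| - l1 * T)_{+} clamp is a no-op and the update
    // reduces to w = -lr * g / (l2 * T * lr + sqrt(k + gg)), which is what the
    // else-branches below compute.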
2486 if (N > 0) {
2487 if (inner_dim > 1) {
2488 const Tindex first_dim_size = var.dim_size(0);
2489 auto indices_vec = indices.vec<Tindex>();
2490 auto var_flat = var.flat_outer_dims<T>();
2491 auto gradient_accum_flat = gradient_accum.flat_outer_dims<T>();
2492 auto gradient_squared_accum_flat =
2493 gradient_squared_accum.flat_outer_dims<T>();
2494 auto grad_flat = grad.flat_outer_dims<T>();
2495 T lr_scalar = lr.scalar<T>()();
2496 T global_step_scalar = global_step.scalar<int64_t>()();
2497 T l1_scalar = l1.scalar<T>()();
2498 T l2_scalar = l2.scalar<T>()();
2499 const double gs_lr = global_step_scalar * lr_scalar;
2500
2501 for (Tindex i = 0; i < N; i++) {
2502 const Tindex index = internal::SubtleMustCopy(indices_vec(i));
2503 OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
2504 errors::InvalidArgument(
2505 strings::StrCat("Index ", index, " at offset ", i,
2506 " in indices is out of range")));
2507 auto ga = gradient_accum_flat.template chip<0>(index);
2508 auto da = gradient_squared_accum_flat.template chip<0>(index);
2509 auto g = grad_flat.template chip<0>(i);
2510 auto v = var_flat.template chip<0>(index);
2511 ga += g;
2512 da += g.square();
2513 if (l1_scalar > 0) {
2514 v = ga.constant(-1.0) * ga.sign() *
2515 ((ga.abs() / ga.constant(global_step_scalar)) -
2516 ga.constant(l1_scalar))
2517 .cwiseMax(static_cast<T>(0.0)) /
2518 (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
2519 } else {
2520 v = ga.constant(-1.0) * (ga / ga.constant(global_step_scalar)) /
2521 (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
2522 }
2523 }
2524 } else {
2525 auto indices_vec = indices.vec<Tindex>();
2526 auto var_flat = var.flat<T>();
2527 auto gradient_accum_flat = gradient_accum.flat<T>();
2528 auto gradient_squared_accum_flat = gradient_squared_accum.flat<T>();
2529 auto grad_flat = grad.flat<T>();
2530 const double lr_scalar = lr.scalar<T>()();
2531 const int64_t global_step_scalar = global_step.scalar<int64_t>()();
2532 const double l1_scalar = l1.scalar<T>()();
2533 const double l2_scalar = l2.scalar<T>()();
2534 const Tindex first_dim_size = var_flat.size();
2535 const double gs_l1 = global_step_scalar * l1_scalar;
2536 const double gs_l2_lr = global_step_scalar * l2_scalar * lr_scalar;
2537
2538 for (Tindex i = 0; i < N; i++) {
2539 const Tindex index = internal::SubtleMustCopy(indices_vec(i));
2540 OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
2541 errors::InvalidArgument(
2542 strings::StrCat("Index ", index, " at offset ", i,
2543 " in indices is out of range")));
2544 T& ga = gradient_accum_flat(index);
2545 T& da = gradient_squared_accum_flat(index);
2546 const double g = grad_flat(i);
2547 ga += g;
2548 da += g * g;
2549 if (l1_scalar > 0) {
2550 var_flat(index) = sgn(-ga) * lr_scalar *
2551 std::max((std::abs(ga) - gs_l1), 0.0) /
2552 (gs_l2_lr + std::sqrt(da));
2553 } else {
2554 var_flat(index) = (-ga * lr_scalar) / (gs_l2_lr + std::sqrt(da));
2555 }
2556 }
2557 }
2558 }
2559
2560 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
2561 }
2562
2563 private:
2564 bool use_exclusive_lock_;
2565};
2566
2567#define REGISTER_KERNELS(T, Tindices) \
2568 REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradDA") \
2569 .Device(DEVICE_CPU) \
2570 .TypeConstraint<T>("T") \
2571 .TypeConstraint<Tindices>("Tindices"), \
2572 SparseApplyAdagradDAOp<T, Tindices>); \
2573 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradDA") \
2574 .Device(DEVICE_CPU) \
2575 .HostMemory("var") \
2576 .HostMemory("gradient_accumulator") \
2577 .HostMemory("gradient_squared_accumulator") \
2578 .TypeConstraint<T>("T") \
2579 .TypeConstraint<Tindices>("Tindices"), \
2580 SparseApplyAdagradDAOp<T, Tindices>);
2581
2582REGISTER_KERNELS(float, int32);
2583REGISTER_KERNELS(float, int64_t);
2584REGISTER_KERNELS(double, int32);
2585REGISTER_KERNELS(double, int64_t);
2586#undef REGISTER_KERNELS
2587
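// FTRL-Proximal. The exact expressions, including the l2_shrinkage and
// multiply_linear_by_lr variants, live in the functor::ApplyFtrl* functors
// (training_ops.h); as a rough sketch of the base update:
//   new_accum = accum + grad^2
//   linear   += grad - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
//   var       = |linear| > l1 ? (sign(linear) * l1 - linear) /
//                                   (new_accum^(-lr_power) / lr + 2 * l2)
//                             : 0
//   accum     = new_accum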
2588template <typename Device, typename T, bool has_l2_shrinkage>
2589class ApplyFtrlOp : public OpKernel {
2590 public:
2591 explicit ApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
2592 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
2593 OP_REQUIRES_OK(
2594 ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_));
2595 }
2596
2597 void Compute(OpKernelContext* ctx) override {
2598 const bool sparse = false;
2599 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
2600 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
2601
2602 Tensor var;
2603 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2604 ctx, 0, use_exclusive_lock_, sparse, &var));
2605 Tensor accum;
2606 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2607 ctx, 1, use_exclusive_lock_, sparse, &accum));
2608 Tensor linear;
2609 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2610 ctx, 2, use_exclusive_lock_, sparse, &linear));
2611 OP_REQUIRES(
2612 ctx, var.IsInitialized(),
2613 errors::FailedPrecondition(
2614 "Attempting to use uninitialized variables: ", requested_input(0)));
2615 OP_REQUIRES(
2616 ctx, accum.IsInitialized(),
2617 errors::FailedPrecondition(
2618 "Attempting to use uninitialized variables: ", requested_input(1)));
2619 OP_REQUIRES(
2620 ctx, linear.IsInitialized(),
2621 errors::FailedPrecondition(
2622 "Attempting to use uninitialized variables: ", requested_input(2)));
2623
2624 const Tensor& grad = ctx->input(3);
2625 OP_REQUIRES(
2626 ctx, var.shape().IsSameSize(accum.shape()),
2627 errors::InvalidArgument("var and accum do not have the same shape",
2628 var.shape().DebugString(), " ",
2629 accum.shape().DebugString()));
2630 OP_REQUIRES(
2631 ctx, var.shape().IsSameSize(linear.shape()),
2632 errors::InvalidArgument("var and linear do not have the same shape",
2633 var.shape().DebugString(), " ",
2634 linear.shape().DebugString()));
2635 OP_REQUIRES(
2636 ctx, var.shape().IsSameSize(grad.shape()),
2637 errors::InvalidArgument("var and grad do not have the same shape",
2638 var.shape().DebugString(), " ",
2639 grad.shape().DebugString()));
2640
2641 const Tensor& lr = ctx->input(4);
2642 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
2643 errors::InvalidArgument("lr is not a scalar: ",
2644 lr.shape().DebugString()));
2645 const Tensor& l1 = ctx->input(5);
2646 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()),
2647 errors::InvalidArgument("l1 regularization strength is not a "
2648 "scalar: ",
2649 l1.shape().DebugString()));
2650 const Tensor& l2 = ctx->input(6);
2651 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()),
2652 errors::InvalidArgument("l2 regularization strength is not a "
2653 "scalar: ",
2654 l2.shape().DebugString()));
2655 const int lr_power_index = has_l2_shrinkage ? 8 : 7;
2656 const Tensor& lr_power = ctx->input(lr_power_index);
2657 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power.shape()),
2658 errors::InvalidArgument("lr_power is not a scalar",
2659 lr_power.shape().DebugString()));
2660
2661 const Device& device = ctx->template eigen_device<Device>();
2662 if (has_l2_shrinkage) {
2663 const Tensor& l2_shrinkage = ctx->input(7);
2664 OP_REQUIRES(
2665 ctx, TensorShapeUtils::IsScalar(l2_shrinkage.shape()),
2666 errors::InvalidArgument("l2 shrinkage regularization strength "
2667 "is not a scalar: ",
2668 l2_shrinkage.shape().DebugString()));
2669 if (multiply_linear_by_lr_) {
2670 functor::ApplyFtrlV2MultiplyLinearByLr<Device, T>()(
2671 device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(),
2672 grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(),
2673 l2_shrinkage.scalar<T>(), lr_power.scalar<T>());
2674 } else {
2675 functor::ApplyFtrlV2<Device, T>()(
2676 device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(),
2677 grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(),
2678 l2_shrinkage.scalar<T>(), lr_power.scalar<T>());
2679 }
2680 } else if (multiply_linear_by_lr_) {
2681 functor::ApplyFtrlMultiplyLinearByLr<Device, T>()(
2682 device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(),
2683 grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(),
2684 lr_power.scalar<T>());
2685 } else {
2686 functor::ApplyFtrl<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
2687 linear.flat<T>(), grad.flat<T>(),
2688 lr.scalar<T>(), l1.scalar<T>(),
2689 l2.scalar<T>(), lr_power.scalar<T>());
2690 }
2691
2692 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
2693 }
2694
2695 private:
2696 bool use_exclusive_lock_;
2697 bool multiply_linear_by_lr_;
2698};
2699
2700#define REGISTER_KERNELS(D, T) \
2701 REGISTER_KERNEL_BUILDER( \
2702 Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
2703 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>); \
2704 REGISTER_KERNEL_BUILDER( \
2705 Name("ResourceApplyFtrl") \
2706 .HostMemory("var") \
2707 .HostMemory("accum") \
2708 .HostMemory("linear") \
2709 .Device(DEVICE_##D) \
2710 .TypeConstraint<T>("T"), \
2711 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>);
2712#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
2713
2714TF_CALL_half(REGISTER_CPU_KERNELS);
2715TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
2716TF_CALL_float(REGISTER_CPU_KERNELS);
2717TF_CALL_double(REGISTER_CPU_KERNELS);
2718
2719#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2720// Forward declarations of the functor specializations for GPU.
2721namespace functor {
2722#define DECLARE_GPU_SPEC(T) \
2723 template <> \
2724 void ApplyFtrl<GPUDevice, T>::operator()( \
2725 const GPUDevice& d, typename TTypes<T>::Flat var, \
2726 typename TTypes<T>::Flat accum, typename TTypes<T>::Flat linear, \
2727 typename TTypes<T>::ConstFlat grad, typename TTypes<T>::ConstScalar lr, \
2728 typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
2729 typename TTypes<T>::ConstScalar lr_power); \
2730 extern template struct ApplyFtrl<GPUDevice, T>;
2731DECLARE_GPU_SPEC(Eigen::half);
2732DECLARE_GPU_SPEC(float);
2733DECLARE_GPU_SPEC(double);
2734#undef DECLARE_GPU_SPEC
2735} // namespace functor
2736
2737REGISTER_KERNELS(GPU, Eigen::half);
2738REGISTER_KERNELS(GPU, float);
2739REGISTER_KERNELS(GPU, double);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2741#undef REGISTER_CPU_KERNELS
2742#undef REGISTER_KERNELS
2743
2744#define REGISTER_KERNELS(D, T) \
2745 REGISTER_KERNEL_BUILDER( \
2746 Name("ApplyFtrlV2").Device(DEVICE_##D).TypeConstraint<T>("T"), \
2747 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>); \
2748 REGISTER_KERNEL_BUILDER( \
2749 Name("ResourceApplyFtrlV2") \
2750 .HostMemory("var") \
2751 .HostMemory("accum") \
2752 .HostMemory("linear") \
2753 .Device(DEVICE_##D) \
2754 .TypeConstraint<T>("T"), \
2755 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>);
2756#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
2757
2758TF_CALL_half(REGISTER_CPU_KERNELS);
2759TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
2760TF_CALL_float(REGISTER_CPU_KERNELS);
2761TF_CALL_double(REGISTER_CPU_KERNELS);
2762
2763#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2764// Forward declarations of the functor specializations for GPU.
2765namespace functor {
2766#define DECLARE_GPU_SPEC(T) \
2767 template <> \
2768 void ApplyFtrlV2<GPUDevice, T>::operator()( \
2769 const GPUDevice& d, typename TTypes<T>::Flat var, \
2770 typename TTypes<T>::Flat accum, typename TTypes<T>::Flat linear, \
2771 typename TTypes<T>::ConstFlat grad, typename TTypes<T>::ConstScalar lr, \
2772 typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
2773 typename TTypes<T>::ConstScalar l2_shrinkage, \
2774 typename TTypes<T>::ConstScalar lr_power); \
2775 extern template struct ApplyFtrlV2<GPUDevice, T>; \
2776 template <> \
2777 void ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T>::operator()( \
2778 const GPUDevice& d, typename TTypes<T>::Flat var, \
2779 typename TTypes<T>::Flat accum, typename TTypes<T>::Flat linear, \
2780 typename TTypes<T>::ConstFlat grad, typename TTypes<T>::ConstScalar lr, \
2781 typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \
2782 typename TTypes<T>::ConstScalar l2_shrinkage, \
2783 typename TTypes<T>::ConstScalar lr_power); \
2784 extern template struct ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T>;
2785DECLARE_GPU_SPEC(Eigen::half);
2786DECLARE_GPU_SPEC(float);
2787DECLARE_GPU_SPEC(double);
2788#undef DECLARE_GPU_SPEC
2789} // namespace functor
2790
2791REGISTER_KERNELS(GPU, Eigen::half);
2792REGISTER_KERNELS(GPU, float);
2793REGISTER_KERNELS(GPU, double);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2795#undef REGISTER_CPU_KERNELS
2796#undef REGISTER_KERNELS
2797
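// Sparse FTRL: the FTRL-Proximal update above restricted to the rows named in
// `indices`, dispatched to functor::SparseApplyFtrl; the has_l2_shrinkage
// template flag selects between the Ftrl and FtrlV2 variants.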
2798template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
2799class SparseApplyFtrlOp : public OpKernel {
2800 public:
2801 explicit SparseApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
2802 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
2803 OP_REQUIRES_OK(
2804 ctx, ctx->GetAttr("multiply_linear_by_lr", &multiply_linear_by_lr_));
2805 }
2806
2807 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
2808 const bool sparse = true;
2809 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
2810 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
2811 Tensor var;
2812 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2813 ctx, 0, use_exclusive_lock_, sparse, &var));
2814 Tensor accum;
2815 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2816 ctx, 1, use_exclusive_lock_, sparse, &accum));
2817 Tensor linear;
2818 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
2819 ctx, 2, use_exclusive_lock_, sparse, &linear));
2820 OP_REQUIRES(
2821 ctx, var.IsInitialized(),
2822 errors::FailedPrecondition(
2823 "Attempting to use uninitialized variables: ", requested_input(0)));
2824 OP_REQUIRES(
2825 ctx, accum.IsInitialized(),
2826 errors::FailedPrecondition(
2827 "Attempting to use uninitialized variables: ", requested_input(1)));
2828 OP_REQUIRES(
2829 ctx, linear.IsInitialized(),
2830 errors::FailedPrecondition(
2831 "Attempting to use uninitialized variables: ", requested_input(2)));
2832 OP_REQUIRES(
2833 ctx, var.shape().IsSameSize(accum.shape()),
2834 errors::InvalidArgument("var and accum do not have the same shape",
2835 var.shape().DebugString(), " ",
2836 accum.shape().DebugString()));
2837 OP_REQUIRES(
2838 ctx, var.shape().IsSameSize(linear.shape()),
2839 errors::InvalidArgument("var and linear do not have the same shape",
2840 var.shape().DebugString(), " ",
2841 linear.shape().DebugString()));
2842 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
2843 errors::InvalidArgument("var must be at least 1 dimensional"));
2844
2845 const Tensor& grad = ctx->input(3);
2846 const Tensor& indices = ctx->input(4);
2847 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
2848 errors::InvalidArgument("indices must be one-dimensional"));
2849
    // Note: The range checks on lr, l1, l2, and lr_power below are disabled
    // for non-CPU devices because their values cannot be accessed directly
    // from the host. The GPU kernel will not crash if these conditions are
    // not met; it will simply produce a bogus answer (possibly inf/nan).
2854 const Tensor& lr = ctx->input(5);
2855 OP_REQUIRES(
2856 ctx,
2857 TensorShapeUtils::IsScalar(lr.shape()) &&
2858 (!std::is_same<Device, CPUDevice>::value ||
2859 lr.scalar<T>()() > static_cast<T>(0) ||
2860 (multiply_linear_by_lr_ && lr.scalar<T>()() >= static_cast<T>(0))),
2861 errors::InvalidArgument("lr is not a positive scalar (or zero if "
2862 "multiply_linear_by_lr is set): ",
2863 lr.shape().DebugString()));
2864
2865 const Tensor& l1 = ctx->input(6);
2866 OP_REQUIRES(ctx,
2867 TensorShapeUtils::IsScalar(l1.shape()) &&
2868 (!std::is_same<Device, CPUDevice>::value ||
2869 l1.scalar<T>()() >= static_cast<T>(0)),
2870 errors::InvalidArgument("l1 regularization strength is not a "
2871 "non-negative scalar: ",
2872 l1.shape().DebugString()));
2873 const Tensor& l2 = ctx->input(7);
2874 OP_REQUIRES(ctx,
2875 TensorShapeUtils::IsScalar(l2.shape()) &&
2876 (!std::is_same<Device, CPUDevice>::value ||
2877 l2.scalar<T>()() >= static_cast<T>(0)),
2878 errors::InvalidArgument("l2 regularization strength is not a "
2879 "non-negative scalar: ",
2880 l2.shape().DebugString()));
2881 const int lr_power_index = has_l2_shrinkage ? 9 : 8;
2882 const Tensor& lr_power = ctx->input(lr_power_index);
2883 OP_REQUIRES(ctx,
2884 TensorShapeUtils::IsScalar(lr_power.shape()) &&
2885 (!std::is_same<Device, CPUDevice>::value ||
2886 lr_power.scalar<T>()() <= static_cast<T>(0)),
2887 errors::InvalidArgument("lr_power is not a "
2888 "non-positive scalar: ",
2889 lr_power.shape().DebugString()));
2890 int64_t inner_dim = 1;
2891 for (int d = 1; d < var.dims(); d++) {
2892 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
2893 errors::InvalidArgument(strings::StrCat(
2894 "var and grad must match in dimension ", d)));
2895 inner_dim *= grad.dim_size(d);
2896 }
2897 const Tindex N = indices.dim_size(0);
2898 OP_REQUIRES(
2899 ctx, grad.dim_size(0) == N,
2900 errors::InvalidArgument(
2901 "grad must be the same size as indices in the first dimension."));
2902
2903 OP_REQUIRES(ctx, inner_dim > 0,
2904 errors::InvalidArgument(
2905 "Inner dimension should be greater than zero."));
2906
2907 const Tensor* l2_shrinkage;
2908 if (has_l2_shrinkage) {
2909 l2_shrinkage = &ctx->input(8);
2910 OP_REQUIRES(
2911 ctx,
2912 TensorShapeUtils::IsScalar(l2_shrinkage->shape()) &&
2913 (!std::is_same<Device, CPUDevice>::value ||
2914 l2_shrinkage->scalar<T>()() >= static_cast<T>(0)),
2915 errors::InvalidArgument("l2 shrinkage regularization strength "
2916 "is not a non-negative scalar: ",
2917 l2_shrinkage->shape().DebugString()));
2918 }
2919
2920 const Device& device = ctx->template eigen_device<Device>();
2921 auto indices_vec = indices.vec<Tindex>();
2922 OP_REQUIRES_OK(
2923 ctx, functor::SparseApplyFtrl<Device, T, Tindex, has_l2_shrinkage>()(
2924 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
2925 linear.flat_outer_dims<T>(), lr.scalar<T>(), l1.scalar<T>(),
2926 l2.scalar<T>(),
                 // Note: l2 is passed as a placeholder when has_l2_shrinkage
                 // is false (it will not be used in that case).
2929 has_l2_shrinkage ? l2_shrinkage->scalar<T>() : l2.scalar<T>(),
2930 lr_power.scalar<T>(), grad.flat_outer_dims<T>(), indices_vec,
2931 inner_dim, multiply_linear_by_lr_));
2932
2933 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
2934 }
2935
2936 private:
2937 bool use_exclusive_lock_;
2938 bool multiply_linear_by_lr_;
2939};
2940
2941#define REGISTER_KERNELS(D, T, Tindices) \
2942 REGISTER_KERNEL_BUILDER( \
2943 Name("SparseApplyFtrl") \
2944 .Device(DEVICE_##D) \
2945 .TypeConstraint<T>("T") \
2946 .TypeConstraint<Tindices>("Tindices"), \
2947 SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/false>); \
2948 REGISTER_KERNEL_BUILDER( \
2949 Name("ResourceSparseApplyFtrl") \
2950 .Device(DEVICE_##D) \
2951 .TypeConstraint<T>("T") \
2952 .TypeConstraint<Tindices>("Tindices"), \
2953 SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/false>);
2954#define REGISTER_CPU_KERNELS(T) \
2955 REGISTER_KERNELS(CPU, T, int32); \
2956 REGISTER_KERNELS(CPU, T, int64_t);
2957
2958TF_CALL_half(REGISTER_CPU_KERNELS);
2959TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
2960TF_CALL_float(REGISTER_CPU_KERNELS);
2961TF_CALL_double(REGISTER_CPU_KERNELS);
2962
2963#undef REGISTER_CPU_KERNELS
2964
2965#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2966// Forward declarations of the functor specializations for GPU.
2967namespace functor {
2968#define DECLARE_GPU_SPEC(T, Tindex) \
2969 template <> \
2970 Status SparseApplyFtrl<GPUDevice, T, Tindex, /*has_l2_shrinkage=*/false>:: \
2971 operator()( \
2972 const GPUDevice& d, typename TTypes<T>::Matrix var, \
2973 typename TTypes<T>::Matrix accum, typename TTypes<T>::Matrix linear, \
2974 typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar l1, \
2975 typename TTypes<T>::ConstScalar l2, \
2976 typename TTypes<T>::ConstScalar l2_shrinkage, \
2977 typename TTypes<T>::ConstScalar lr_power, \
2978 typename TTypes<T>::ConstMatrix grad, \
2979 typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \
2980 bool multiply_linear_by_lr); \
2981 extern template struct SparseApplyFtrl<GPUDevice, T, Tindex, \
2982 /*has_l2_shrinkage=*/false>;
2983DECLARE_GPU_SPEC(Eigen::half, int32);
2984DECLARE_GPU_SPEC(Eigen::half, int64_t);
2985DECLARE_GPU_SPEC(float, int32);
2986DECLARE_GPU_SPEC(float, int64_t);
2987DECLARE_GPU_SPEC(double, int32);
2988DECLARE_GPU_SPEC(double, int64_t);
2989#undef DECLARE_GPU_SPEC
2990} // namespace functor
2991
2992REGISTER_KERNELS(GPU, Eigen::half, int32);
2993REGISTER_KERNELS(GPU, Eigen::half, int64_t);
2994REGISTER_KERNELS(GPU, float, int32);
2995REGISTER_KERNELS(GPU, float, int64_t);
2996REGISTER_KERNELS(GPU, double, int32);
2997REGISTER_KERNELS(GPU, double, int64_t);
2998#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2999#undef REGISTER_KERNELS
3000
3001#define REGISTER_KERNELS(D, T, Tindices) \
3002 REGISTER_KERNEL_BUILDER( \
3003 Name("SparseApplyFtrlV2") \
3004 .Device(DEVICE_##D) \
3005 .TypeConstraint<T>("T") \
3006 .TypeConstraint<Tindices>("Tindices"), \
3007 SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/true>); \
3008 REGISTER_KERNEL_BUILDER( \
3009 Name("ResourceSparseApplyFtrlV2") \
3010 .Device(DEVICE_##D) \
3011 .TypeConstraint<T>("T") \
3012 .TypeConstraint<Tindices>("Tindices"), \
3013 SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/true>);
3014#define REGISTER_CPU_KERNELS(T) \
3015 REGISTER_KERNELS(CPU, T, int32); \
3016 REGISTER_KERNELS(CPU, T, int64_t);
3017
3018TF_CALL_half(REGISTER_CPU_KERNELS);
3019TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
3020TF_CALL_float(REGISTER_CPU_KERNELS);
3021TF_CALL_double(REGISTER_CPU_KERNELS);
3022
3023#undef REGISTER_CPU_KERNELS
3024
3025#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3026// Forward declarations of the functor specializations for GPU.
3027namespace functor {
3028#define DECLARE_GPU_SPEC(T, Tindex) \
3029 template <> \
3030 Status SparseApplyFtrl<GPUDevice, T, Tindex, /*has_l2_shrinkage=*/true>:: \
3031 operator()( \
3032 const GPUDevice& d, typename TTypes<T>::Matrix var, \
3033 typename TTypes<T>::Matrix accum, typename TTypes<T>::Matrix linear, \
3034 typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar l1, \
3035 typename TTypes<T>::ConstScalar l2, \
3036 typename TTypes<T>::ConstScalar l2_shrinkage, \
3037 typename TTypes<T>::ConstScalar lr_power, \
3038 typename TTypes<T>::ConstMatrix grad, \
3039 typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \
3040 bool multiply_linear_by_lr); \
3041 extern template struct SparseApplyFtrl<GPUDevice, T, Tindex, \
3042 /*has_l2_shrinkage=*/true>;
3043DECLARE_GPU_SPEC(Eigen::half, int32);
3044DECLARE_GPU_SPEC(Eigen::half, int64_t);
3045DECLARE_GPU_SPEC(float, int32);
3046DECLARE_GPU_SPEC(float, int64_t);
3047DECLARE_GPU_SPEC(double, int32);
3048DECLARE_GPU_SPEC(double, int64_t);
3049#undef DECLARE_GPU_SPEC
3050} // namespace functor
3051
3052REGISTER_KERNELS(GPU, Eigen::half, int32);
3053REGISTER_KERNELS(GPU, Eigen::half, int64_t);
3054REGISTER_KERNELS(GPU, float, int32);
3055REGISTER_KERNELS(GPU, float, int64_t);
3056REGISTER_KERNELS(GPU, double, int32);
3057REGISTER_KERNELS(GPU, double, int64_t);
3058#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3059#undef REGISTER_KERNELS
3060
3061template <typename Device, typename T>
3062class ApplyMomentumOp : public OpKernel {
3063 public:
3064 explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
3065 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3066 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
3067 }
3068
3069 void Compute(OpKernelContext* ctx) override {
3070 const bool sparse = false;
3071 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
3072 ctx, use_exclusive_lock_, sparse, {0, 1});
3073
3074 Tensor var;
3075 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3076 ctx, 0, use_exclusive_lock_, sparse, &var));
3077 Tensor accum;
3078 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3079 ctx, 1, use_exclusive_lock_, sparse, &accum));
3080 OP_REQUIRES(
3081 ctx, var.IsInitialized(),
3082 errors::FailedPrecondition(
3083 "Attempting to use uninitialized variables: ", requested_input(0)));
3084 OP_REQUIRES(
3085 ctx, accum.IsInitialized(),
3086 errors::FailedPrecondition(
3087 "Attempting to use uninitialized variables: ", requested_input(1)));
3088 const Tensor& lr = ctx->input(2);
3089 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
3090 errors::InvalidArgument("lr is not a scalar: ",
3091 lr.shape().DebugString()));
3092 const Tensor& grad = ctx->input(3);
3093 OP_REQUIRES(
3094 ctx, var.shape().IsSameSize(accum.shape()),
3095 errors::InvalidArgument("var and accum do not have the same shape",
3096 var.shape().DebugString(), " ",
3097 accum.shape().DebugString()));
3098 OP_REQUIRES(
3099 ctx, var.shape().IsSameSize(grad.shape()),
3100 errors::InvalidArgument("var and grad do not have the same shape",
3101 var.shape().DebugString(), " ",
3102 grad.shape().DebugString()));
3103
3104 const Tensor& momentum = ctx->input(4);
3105 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
3106 errors::InvalidArgument("momentum is not a scalar: ",
3107 momentum.shape().DebugString()));
3108
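    // The functor applies the same element-wise update that the sparse kernel
    // below spells out explicitly:
    //   accum = momentum * accum + grad
    //   var  -= use_nesterov ? lr * grad + lr * momentum * accum
    //                        : lr * accum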
3109 const Device& device = ctx->template eigen_device<Device>();
3110 functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
3111 lr.scalar<T>(), grad.flat<T>(),
3112 momentum.scalar<T>(), use_nesterov_);
3113 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
3114 }
3115
3116 private:
3117 bool use_exclusive_lock_;
3118 bool use_nesterov_;
3119};
3120
3121#define REGISTER_KERNELS(D, T) \
3122 REGISTER_KERNEL_BUILDER( \
3123 Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
3124 ApplyMomentumOp<D##Device, T>); \
3125 REGISTER_KERNEL_BUILDER(Name("ResourceApplyMomentum") \
3126 .Device(DEVICE_##D) \
3127 .HostMemory("var") \
3128 .HostMemory("accum") \
3129 .TypeConstraint<T>("T"), \
3130 ApplyMomentumOp<D##Device, T>);
3131#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
3132
3133TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
3134TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
3135
3136#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3137// Forward declarations of the functor specializations for GPU.
3138namespace functor {
3139#define DECLARE_GPU_SPEC(T) \
3140 template <> \
3141 void ApplyMomentum<GPUDevice, T>::operator()( \
3142 const GPUDevice& d, typename TTypes<T>::Flat var, \
3143 typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
3144 typename TTypes<T>::ConstFlat grad, \
3145 typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \
3146 extern template struct ApplyMomentum<GPUDevice, T>;
3147DECLARE_GPU_SPEC(Eigen::half);
3148DECLARE_GPU_SPEC(float);
3149DECLARE_GPU_SPEC(double);
3150DECLARE_GPU_SPEC(complex64);
3151DECLARE_GPU_SPEC(complex128);
3152#undef DECLARE_GPU_SPEC
3153} // namespace functor
3154
3155REGISTER_KERNELS(GPU, Eigen::half);
3156REGISTER_KERNELS(GPU, float);
3157REGISTER_KERNELS(GPU, double);
3158REGISTER_KERNELS(GPU, complex64);
3159REGISTER_KERNELS(GPU, complex128);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3161#undef REGISTER_CPU_KERNELS
3162#undef REGISTER_KERNELS
3163
// Note: this op runs on the CPU only.
3165template <typename T, typename Tindex>
3166class SparseApplyMomentumOp : public OpKernel {
3167 public:
3168 explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
3169 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3170 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
3171 }
3172
3173 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
3174 const bool sparse = true;
3175 auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
3176 ctx, use_exclusive_lock_, sparse, {0, 1});
3177
3178 Tensor var;
3179 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
3180 ctx, 0, use_exclusive_lock_, sparse, &var));
3181 Tensor accum;
3182 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
3183 ctx, 1, use_exclusive_lock_, sparse, &accum));
3184 OP_REQUIRES(
3185 ctx, var.IsInitialized(),
3186 errors::FailedPrecondition(
3187 "Attempting to use uninitialized variables: ", requested_input(0)));
3188 OP_REQUIRES(
3189 ctx, accum.IsInitialized(),
3190 errors::FailedPrecondition(
3191 "Attempting to use uninitialized variables: ", requested_input(1)));
3192 OP_REQUIRES(
3193 ctx, var.shape().IsSameSize(accum.shape()),
3194 errors::InvalidArgument("var and accum do not have the same shape",
3195 var.shape().DebugString(), " ",
3196 accum.shape().DebugString()));
3197 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
3198 errors::InvalidArgument("var must be at least 1 dimensional"));
3199
3200 const Tensor& lr = ctx->input(2);
3201 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
3203 lr.shape().DebugString()));
3204 const Tensor& grad = ctx->input(3);
3205 const Tensor& indices = ctx->input(4);
3206 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
3207 errors::InvalidArgument("indices must be one-dimensional"));
3208
3209 for (int d = 1; d < var.dims(); d++) {
3210 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
3211 errors::InvalidArgument(strings::StrCat(
3212 "var and grad must match in dimension ", d)));
3213 }
3214 const Tindex N = indices.dim_size(0);
3215 OP_REQUIRES(
3216 ctx, grad.dim_size(0) == N,
3217 errors::InvalidArgument(
3218 "grad must be the same size as indices in the first dimension."));
3219
3220 const Tensor& momentum = ctx->input(5);
3221 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
3222 errors::InvalidArgument("momentum is not a scalar: ",
3223 momentum.shape().DebugString()));
3224
3225 if (N > 0) {
3226 const Tindex first_dim_size = var.dim_size(0);
3227 auto indices_vec = indices.vec<Tindex>();
3228 auto var_flat = var.flat_outer_dims<T>();
3229 auto accum_flat = accum.flat_outer_dims<T>();
3230 auto grad_flat = grad.flat_outer_dims<T>();
3231 T lr_scalar = lr.scalar<T>()();
3232 T momentum_scalar = momentum.scalar<T>()();
3233
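      // Per-row momentum update:
      //   accum = momentum * accum + grad
      //   var  -= use_nesterov ? lr * grad + lr * momentum * accum
      //                        : lr * accum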
3234 for (Tindex i = 0; i < N; i++) {
3235 const Tindex index = internal::SubtleMustCopy(indices_vec(i));
3236 OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
3237 errors::InvalidArgument(
3238 strings::StrCat("Index ", index, " at offset ", i,
3239 " in indices is out of range")));
3240 auto a = accum_flat.template chip<0>(index);
3241 auto g = grad_flat.template chip<0>(i);
3242 auto v = var_flat.template chip<0>(index);
3243 a = a * a.constant(momentum_scalar) + g;
3244 if (use_nesterov_) {
3245 v -= g.constant(lr_scalar) * g +
3246 a.constant(lr_scalar) * a.constant(momentum_scalar) * a;
3247 } else {
3248 v -= a.constant(lr_scalar) * a;
3249 }
3250 }
3251 }
3252
3253 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
3254 }
3255
3256 private:
3257 bool use_exclusive_lock_;
3258 bool use_nesterov_;
3259};
3260
3261#define REGISTER_KERNELS(T, Tindices) \
3262 REGISTER_KERNEL_BUILDER(Name("SparseApplyMomentum") \
3263 .Device(DEVICE_CPU) \
3264 .TypeConstraint<T>("T") \
3265 .TypeConstraint<Tindices>("Tindices"), \
3266 SparseApplyMomentumOp<T, Tindices>); \
3267 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyMomentum") \
3268 .Device(DEVICE_CPU) \
3269 .TypeConstraint<T>("T") \
3270 .TypeConstraint<Tindices>("Tindices"), \
3271 SparseApplyMomentumOp<T, Tindices>);
3272#define REGISTER_CPU_KERNELS(T) \
3273 REGISTER_KERNELS(T, int32); \
3274 REGISTER_KERNELS(T, int64_t);
3275
3276TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
3277TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
3278
3279#undef REGISTER_CPU_KERNELS
3280#undef REGISTER_KERNELS
3281
3282template <typename Device, typename T>
3283class ApplyKerasMomentumOp : public OpKernel {
3284 public:
3285 explicit ApplyKerasMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
3286 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3287 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
3288 }
3289
3290 void Compute(OpKernelContext* ctx) override {
3291 const bool sparse = false;
3292 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
3293 ctx, use_exclusive_lock_, sparse, {0, 1});
3294
3295 Tensor var;
3296 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3297 ctx, 0, use_exclusive_lock_, sparse, &var));
3298 Tensor accum;
3299 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3300 ctx, 1, use_exclusive_lock_, sparse, &accum));
3301 OP_REQUIRES(
3302 ctx, var.IsInitialized(),
3303 errors::FailedPrecondition(
3304 "Attempting to use uninitialized variables: ", requested_input(0)));
3305 OP_REQUIRES(
3306 ctx, accum.IsInitialized(),
3307 errors::FailedPrecondition(
3308 "Attempting to use uninitialized variables: ", requested_input(1)));
3309 const Tensor& lr = ctx->input(2);
3310 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
3311 errors::InvalidArgument("lr is not a scalar: ",
3312 lr.shape().DebugString()));
3313 const Tensor& grad = ctx->input(3);
3314 OP_REQUIRES(
3315 ctx, var.shape().IsSameSize(accum.shape()),
3316 errors::InvalidArgument("var and accum do not have the same shape",
3317 var.shape().DebugString(), " ",
3318 accum.shape().DebugString()));
3319 OP_REQUIRES(
3320 ctx, var.shape().IsSameSize(grad.shape()),
3321 errors::InvalidArgument("var and grad do not have the same shape",
3322 var.shape().DebugString(), " ",
3323 grad.shape().DebugString()));
3324
3325 const Tensor& momentum = ctx->input(4);
3326 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
3327 errors::InvalidArgument("momentum is not a scalar: ",
3328 momentum.shape().DebugString()));
3329
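    // Keras-style momentum folds the learning rate into the accumulator;
    // roughly (the exact form lives in the ApplyKerasMomentum functor):
    //   accum = momentum * accum - lr * grad
    //   var  += use_nesterov ? momentum * accum - lr * grad
    //                        : accum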
3330 const Device& device = ctx->template eigen_device<Device>();
3331 functor::ApplyKerasMomentum<Device, T>()(
3332 device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), grad.flat<T>(),
3333 momentum.scalar<T>(), use_nesterov_);
3334 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
3335 }
3336
3337 private:
3338 bool use_exclusive_lock_;
3339 bool use_nesterov_;
3340};
3341
3342#define REGISTER_KERNELS(D, T) \
3343 REGISTER_KERNEL_BUILDER(Name("ResourceApplyKerasMomentum") \
3344 .Device(DEVICE_##D) \
3345 .HostMemory("var") \
3346 .HostMemory("accum") \
3347 .TypeConstraint<T>("T"), \
3348 ApplyKerasMomentumOp<D##Device, T>);
3349#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
3350
3351TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
3352TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
3353
3354#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3355// Forward declarations of the functor specializations for GPU.
3356namespace functor {
3357#define DECLARE_GPU_SPEC(T) \
3358 template <> \
3359 void ApplyKerasMomentum<GPUDevice, T>::operator()( \
3360 const GPUDevice& d, typename TTypes<T>::Flat var, \
3361 typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
3362 typename TTypes<T>::ConstFlat grad, \
3363 typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \
3364 extern template struct ApplyKerasMomentum<GPUDevice, T>;
3365DECLARE_GPU_SPEC(Eigen::half);
3366DECLARE_GPU_SPEC(float);
3367DECLARE_GPU_SPEC(double);
3368DECLARE_GPU_SPEC(complex64);
3369DECLARE_GPU_SPEC(complex128);
3370#undef DECLARE_GPU_SPEC
3371} // namespace functor
3372
3373REGISTER_KERNELS(GPU, Eigen::half);
3374REGISTER_KERNELS(GPU, float);
3375REGISTER_KERNELS(GPU, double);
3376REGISTER_KERNELS(GPU, complex64);
3377REGISTER_KERNELS(GPU, complex128);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3379#undef REGISTER_CPU_KERNELS
3380#undef REGISTER_KERNELS
3381
// Note: unlike SparseApplyMomentumOp above, this op is templated on the device
// type and is also registered on GPU (see the kernel registrations below).
3383template <typename T, typename Device, typename Tindex>
3384class SparseApplyKerasMomentumOp : public OpKernel {
3385 public:
3386 explicit SparseApplyKerasMomentumOp(OpKernelConstruction* ctx)
3387 : OpKernel(ctx) {
3388 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3389 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
3390 }
3391
3392 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
3393 const bool sparse = true;
3394 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
3395 ctx, use_exclusive_lock_, sparse, {0, 1});
3396
3397 Tensor var;
3398 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3399 ctx, 0, use_exclusive_lock_, sparse, &var));
3400 Tensor accum;
3401 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3402 ctx, 1, use_exclusive_lock_, sparse, &accum));
3403 OP_REQUIRES(
3404 ctx, var.IsInitialized(),
3405 errors::FailedPrecondition(
3406 "Attempting to use uninitialized variables: ", requested_input(0)));
3407 OP_REQUIRES(
3408 ctx, accum.IsInitialized(),
3409 errors::FailedPrecondition(
3410 "Attempting to use uninitialized variables: ", requested_input(1)));
3411 OP_REQUIRES(
3412 ctx, var.shape().IsSameSize(accum.shape()),
3413 errors::InvalidArgument("var and accum do not have the same shape",
3414 var.shape().DebugString(), " ",
3415 accum.shape().DebugString()));
3416 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
3417 errors::InvalidArgument("var must be at least 1 dimensional"));
3418
3419 const Tensor& lr = ctx->input(2);
3420 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
3422 lr.shape().DebugString()));
3423 const Tensor& grad = ctx->input(3);
3424 const Tensor& indices = ctx->input(4);
3425 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
3426 errors::InvalidArgument("indices must be one-dimensional"));
3427
3428 for (int d = 1; d < var.dims(); d++) {
3429 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
3430 errors::InvalidArgument(strings::StrCat(
3431 "var and grad must match in dimension ", d)));
3432 }
3433 const Tindex N = indices.dim_size(0);
3434 OP_REQUIRES(
3435 ctx, grad.dim_size(0) == N,
3436 errors::InvalidArgument(
3437 "grad must be the same size as indices in the first dimension."));
3438
3439 const Tensor& momentum = ctx->input(5);
3440 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
3441 errors::InvalidArgument("momentum is not a scalar: ",
3442 momentum.shape().DebugString()));
3443
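    // The functor applies the Keras-style momentum update to the selected rows
    // and returns the first out-of-range index it encounters (or a negative
    // value if every index is valid), which is checked below.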
3444 const Device& device = ctx->template eigen_device<Device>();
3445 auto indices_flat = indices.flat<Tindex>();
3446 const Tindex bad_i = functor::SparseApplyKerasMomentum<Device, T, Tindex>()(
3447 device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(),
3448 lr.scalar<T>(), grad.flat_outer_dims<T>(), indices_flat,
3449 momentum.scalar<T>(), use_nesterov_);
3450 OP_REQUIRES(
3451 ctx, bad_i < 0,
3452 errors::InvalidArgument(
3453 "indices", SliceDebugString(indices.shape(), bad_i), " = ",
3454 indices_flat(bad_i), " is not in [0, ", var.dim_size(0), ")"));
3455
3456 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
3457 }
3458
3459 private:
3460 bool use_exclusive_lock_;
3461 bool use_nesterov_;
3462};
3463
3464#define REGISTER_KERNELS(T, D, Tindices) \
3465 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyKerasMomentum") \
3466 .Device(DEVICE_##D) \
3467 .TypeConstraint<T>("T") \
3468 .TypeConstraint<Tindices>("Tindices"), \
3469 SparseApplyKerasMomentumOp<T, D##Device, Tindices>);
3470#define REGISTER_CPU_KERNELS(T) \
3471 REGISTER_KERNELS(T, CPU, int32); \
3472 REGISTER_KERNELS(T, CPU, int64_t);
3473
3474TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
3475TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
3476
3477#undef REGISTER_CPU_KERNELS
3478
3479#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3480// Forward declarations of the functor specializations for GPU.
3481namespace functor {
3482#define DECLARE_GPU_SPEC(T, Tindex) \
3483 template <> \
3484 Tindex SparseApplyKerasMomentum<GPUDevice, T, Tindex>::operator()( \
3485 const GPUDevice& d, typename TTypes<T>::Matrix var, \
3486 typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \
3487 typename TTypes<T>::ConstMatrix grad, \
3488 typename TTypes<Tindex>::ConstFlat indices, \
3489 typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \
3490 extern template struct SparseApplyKerasMomentum<GPUDevice, T, Tindex>;
3491DECLARE_GPU_SPEC(Eigen::half, int32);
3492DECLARE_GPU_SPEC(Eigen::half, int64_t);
3493DECLARE_GPU_SPEC(float, int32);
3494DECLARE_GPU_SPEC(float, int64_t);
3495DECLARE_GPU_SPEC(double, int32);
3496DECLARE_GPU_SPEC(double, int64_t);
3497DECLARE_GPU_SPEC(complex64, int32);
3498DECLARE_GPU_SPEC(complex64, int64_t);
3499DECLARE_GPU_SPEC(complex128, int32);
3500DECLARE_GPU_SPEC(complex128, int64_t);
3501#undef DECLARE_GPU_SPEC
3502} // namespace functor
3503
3504#define REGISTER_GPU_KERNELS(T) \
3505 REGISTER_KERNELS(T, GPU, int32); \
3506 REGISTER_KERNELS(T, GPU, int64_t);
3507
3508REGISTER_GPU_KERNELS(Eigen::half);
3509REGISTER_GPU_KERNELS(float);
3510REGISTER_GPU_KERNELS(double);
3511REGISTER_GPU_KERNELS(complex64);
3512REGISTER_GPU_KERNELS(complex128);
3513#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3515#undef REGISTER_KERNELS
3516
3517template <typename Device, typename T>
3518class ApplyAdamOp : public OpKernel {
3519 public:
3520 explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
3521 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3522 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
3523 }
3524
3525 void Compute(OpKernelContext* ctx) override {
3526 const bool sparse = false;
3527 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
3528 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
3529
3530 Tensor var;
3531 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3532 ctx, 0, use_exclusive_lock_, sparse, &var));
3533 Tensor m;
3534 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3535 ctx, 1, use_exclusive_lock_, sparse, &m));
3536 Tensor v;
3537 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3538 ctx, 2, use_exclusive_lock_, sparse, &v));
3539 OP_REQUIRES(
3540 ctx, var.IsInitialized(),
3541 errors::FailedPrecondition(
3542 "Attempting to use uninitialized variables: ", requested_input(0)));
3543 OP_REQUIRES(
3544 ctx, m.IsInitialized(),
3545 errors::FailedPrecondition(
3546 "Attempting to use uninitialized variables: ", requested_input(1)));
3547 OP_REQUIRES(
3548 ctx, v.IsInitialized(),
3549 errors::FailedPrecondition(
3550 "Attempting to use uninitialized variables: ", requested_input(2)));
3551
3552 const Tensor& beta1_power = ctx->input(3);
3553 const Tensor& beta2_power = ctx->input(4);
3554 const Tensor& lr = ctx->input(5);
3555 const Tensor& beta1 = ctx->input(6);
3556 const Tensor& beta2 = ctx->input(7);
3557 const Tensor& epsilon = ctx->input(8);
3558
3559 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
3560 errors::InvalidArgument("beta1_power is not a scalar: ",
3561 beta1_power.shape().DebugString()));
3562 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()),
3563 errors::InvalidArgument("beta2_power is not a scalar: ",
3564 beta2_power.shape().DebugString()));
3565 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
3567 lr.shape().DebugString()));
3568 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
3569 errors::InvalidArgument("beta1 is not a scalar: ",
3570 beta1.shape().DebugString()));
3571 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
3572 errors::InvalidArgument("beta2 is not a scalar: ",
3573 beta2.shape().DebugString()));
3574 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
3575 errors::InvalidArgument("epsilon is not a scalar: ",
3576 epsilon.shape().DebugString()));
3577
3578 const Tensor& grad = ctx->input(9);
3579 OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
3580 errors::InvalidArgument("var and m do not have the same shape",
3581 var.shape().DebugString(), " ",
3582 m.shape().DebugString()));
3583 OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
3584 errors::InvalidArgument("var and v do not have the same shape",
3585 var.shape().DebugString(), " ",
3586 v.shape().DebugString()));
3587 OP_REQUIRES(
3588 ctx, var.shape().IsSameSize(grad.shape()),
3589 errors::InvalidArgument("var and grad do not have the same shape",
3590 var.shape().DebugString(), " ",
3591 grad.shape().DebugString()));
3592
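    // A sketch of the Adam update with the bias correction folded into the
    // step size (the exact form, including the Nesterov variant selected by
    // use_nesterov_, lives in the ApplyAdam functor):
    //   alpha = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
    //   m    += (grad - m) * (1 - beta1)
    //   v    += (grad^2 - v) * (1 - beta2)
    //   var  -= m * alpha / (sqrt(v) + epsilon)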
3593 const Device& device = ctx->template eigen_device<Device>();
3594 functor::ApplyAdam<Device, T>()(
3595 device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
3596 beta1_power.scalar<T>(), beta2_power.scalar<T>(), lr.scalar<T>(),
3597 beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
3598 grad.flat<T>(), use_nesterov_);
3599
3600 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
3601 }
3602
3603 private:
3604 bool use_exclusive_lock_;
3605 bool use_nesterov_;
3606};
3607
3608#define REGISTER_KERNELS(D, T) \
3609 REGISTER_KERNEL_BUILDER( \
3610 Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
3611 ApplyAdamOp<D##Device, T>); \
3612 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdam") \
3613 .HostMemory("var") \
3614 .HostMemory("m") \
3615 .HostMemory("v") \
3616 .Device(DEVICE_##D) \
3617 .TypeConstraint<T>("T"), \
3618 ApplyAdamOp<D##Device, T>);
3619#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
3620
3621TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
3622TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
3623
3624#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3625// Forward declarations of the functor specializations for GPU.
3626namespace functor {
3627#define DECLARE_GPU_SPEC(T) \
3628 template <> \
3629 void ApplyAdam<GPUDevice, T>::operator()( \
3630 const GPUDevice& d, typename TTypes<T>::Flat var, \
3631 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
3632 typename TTypes<T>::ConstScalar beta1_power, \
3633 typename TTypes<T>::ConstScalar beta2_power, \
3634 typename TTypes<T>::ConstScalar lr, \
3635 typename TTypes<T>::ConstScalar beta1, \
3636 typename TTypes<T>::ConstScalar beta2, \
3637 typename TTypes<T>::ConstScalar epsilon, \
3638 typename TTypes<T>::ConstFlat grad, bool use_nesterov); \
3639 extern template struct ApplyAdam<GPUDevice, T>;
3640DECLARE_GPU_SPEC(Eigen::half);
3641DECLARE_GPU_SPEC(float);
3642DECLARE_GPU_SPEC(double);
3643DECLARE_GPU_SPEC(complex64);
3644DECLARE_GPU_SPEC(complex128);
3645#undef DECLARE_GPU_SPEC
3646} // namespace functor
3647
3648REGISTER_KERNELS(GPU, Eigen::half);
3649REGISTER_KERNELS(GPU, float);
3650REGISTER_KERNELS(GPU, double);
3651REGISTER_KERNELS(GPU, complex64);
3652REGISTER_KERNELS(GPU, complex128);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3654#undef REGISTER_CPU_KERNELS
3655#undef REGISTER_KERNELS
3656
3657template <typename Device, typename T>
3658class ApplyAdamWithAmsgradOp : public OpKernel {
3659 public:
3660 explicit ApplyAdamWithAmsgradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
3661 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3662 }
3663
3664 void Compute(OpKernelContext* ctx) override {
3665 const bool sparse = false;
3666 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
3667 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
3668
3669 Tensor var;
3670 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3671 ctx, 0, use_exclusive_lock_, sparse, &var));
3672 Tensor m;
3673 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3674 ctx, 1, use_exclusive_lock_, sparse, &m));
3675 Tensor v;
3676 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3677 ctx, 2, use_exclusive_lock_, sparse, &v));
3678 Tensor vhat;
3679 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3680 ctx, 3, use_exclusive_lock_, sparse, &vhat));
3681 OP_REQUIRES(
3682 ctx, var.IsInitialized(),
3683 errors::FailedPrecondition(
3684 "Attempting to use uninitialized variables: ", requested_input(0)));
3685 OP_REQUIRES(
3686 ctx, m.IsInitialized(),
3687 errors::FailedPrecondition(
3688 "Attempting to use uninitialized variables: ", requested_input(1)));
3689 OP_REQUIRES(
3690 ctx, v.IsInitialized(),
3691 errors::FailedPrecondition(
3692 "Attempting to use uninitialized variables: ", requested_input(2)));
3693 OP_REQUIRES(
3694 ctx, vhat.IsInitialized(),
3695 errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(3)));
3697
3698 const Tensor& beta1_power = ctx->input(4);
3699 const Tensor& beta2_power = ctx->input(5);
3700 const Tensor& lr = ctx->input(6);
3701 const Tensor& beta1 = ctx->input(7);
3702 const Tensor& beta2 = ctx->input(8);
3703 const Tensor& epsilon = ctx->input(9);
3704
3705 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
3706 errors::InvalidArgument("beta1_power is not a scalar: ",
3707 beta1_power.shape().DebugString()));
3708 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()),
3709 errors::InvalidArgument("beta2_power is not a scalar: ",
3710 beta2_power.shape().DebugString()));
3711 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
3713 lr.shape().DebugString()));
3714 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
3715 errors::InvalidArgument("beta1 is not a scalar: ",
3716 beta1.shape().DebugString()));
3717 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
3718 errors::InvalidArgument("beta2 is not a scalar: ",
3719 beta2.shape().DebugString()));
3720 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
3721 errors::InvalidArgument("epsilon is not a scalar: ",
3722 epsilon.shape().DebugString()));
3723
3724 const Tensor& grad = ctx->input(10);
3725 OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
3726 errors::InvalidArgument("var and m do not have the same shape",
3727 var.shape().DebugString(), " ",
3728 m.shape().DebugString()));
3729 OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
3730 errors::InvalidArgument("var and v do not have the same shape",
3731 var.shape().DebugString(), " ",
3732 v.shape().DebugString()));
3733 OP_REQUIRES(
3734 ctx, var.shape().IsSameSize(grad.shape()),
3735 errors::InvalidArgument("var and grad do not have the same shape",
3736 var.shape().DebugString(), " ",
3737 grad.shape().DebugString()));
3738
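    // AMSGrad differs from plain Adam by keeping the running maximum of the
    // second-moment estimate and using it in the denominator; roughly (exact
    // form in the ApplyAdamWithAmsgrad functor):
    //   vhat = max(vhat, v)
    //   var -= lr * sqrt(1 - beta2_power) / (1 - beta1_power)
    //          * m / (sqrt(vhat) + epsilon)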
3739 const Device& device = ctx->template eigen_device<Device>();
3740 functor::ApplyAdamWithAmsgrad<Device, T>()(
3741 device, var.flat<T>(), m.flat<T>(), v.flat<T>(), vhat.flat<T>(),
3742 beta1_power.scalar<T>(), beta2_power.scalar<T>(), lr.scalar<T>(),
3743 beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
3744 grad.flat<T>());
3745
3746 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
3747 }
3748
3749 private:
3750 bool use_exclusive_lock_;
3751};
3752
3753#define REGISTER_KERNELS(D, T) \
3754 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdamWithAmsgrad") \
3755 .HostMemory("var") \
3756 .HostMemory("m") \
3757 .HostMemory("v") \
3758 .HostMemory("vhat") \
3759 .Device(DEVICE_##D) \
3760 .TypeConstraint<T>("T"), \
3761 ApplyAdamWithAmsgradOp<D##Device, T>);
3762#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
3763
3764TF_CALL_half(REGISTER_CPU_KERNELS);
3765TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
3766TF_CALL_float(REGISTER_CPU_KERNELS);
3767TF_CALL_double(REGISTER_CPU_KERNELS);
3768
3769#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3770// Forward declarations of the functor specializations for GPU.
3771namespace functor {
3772#define DECLARE_GPU_SPEC(T) \
3773 template <> \
3774 void ApplyAdamWithAmsgrad<GPUDevice, T>::operator()( \
3775 const GPUDevice& d, typename TTypes<T>::Flat var, \
3776 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
3777 typename TTypes<T>::Flat vhat, \
3778 typename TTypes<T>::ConstScalar beta1_power, \
3779 typename TTypes<T>::ConstScalar beta2_power, \
3780 typename TTypes<T>::ConstScalar lr, \
3781 typename TTypes<T>::ConstScalar beta1, \
3782 typename TTypes<T>::ConstScalar beta2, \
3783 typename TTypes<T>::ConstScalar epsilon, \
3784 typename TTypes<T>::ConstFlat grad); \
3785 extern template struct ApplyAdamWithAmsgrad<GPUDevice, T>;
3786DECLARE_GPU_SPEC(Eigen::half);
3787DECLARE_GPU_SPEC(float);
3788DECLARE_GPU_SPEC(double);
3789#undef DECLARE_GPU_SPEC
3790} // namespace functor
3791
3792REGISTER_KERNELS(GPU, Eigen::half);
3793REGISTER_KERNELS(GPU, float);
3794REGISTER_KERNELS(GPU, double);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3796#undef REGISTER_CPU_KERNELS
3797#undef REGISTER_KERNELS
3798
3799template <typename Device, typename T>
3800class ApplyAdaMaxOp : public OpKernel {
3801 public:
3802 explicit ApplyAdaMaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
3803 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3804 }
3805
3806 void Compute(OpKernelContext* ctx) override {
3807 const bool sparse = false;
3808 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
3809 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
3810
3811 Tensor var;
3812 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3813 ctx, 0, use_exclusive_lock_, sparse, &var));
3814 Tensor m;
3815 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3816 ctx, 1, use_exclusive_lock_, sparse, &m));
3817 Tensor v;
3818 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3819 ctx, 2, use_exclusive_lock_, sparse, &v));
3820 OP_REQUIRES(
3821 ctx, var.IsInitialized(),
3822 errors::FailedPrecondition(
3823 "Attempting to use uninitialized variables: ", requested_input(0)));
3824 OP_REQUIRES(
3825 ctx, m.IsInitialized(),
3826 errors::FailedPrecondition(
3827 "Attempting to use uninitialized variables: ", requested_input(1)));
3828 OP_REQUIRES(
3829 ctx, v.IsInitialized(),
3830 errors::FailedPrecondition(
3831 "Attempting to use uninitialized variables: ", requested_input(2)));
3832
3833 const Tensor& beta1_power = ctx->input(3);
3834 const Tensor& lr = ctx->input(4);
3835 const Tensor& beta1 = ctx->input(5);
3836 const Tensor& beta2 = ctx->input(6);
3837 const Tensor& epsilon = ctx->input(7);
3838
3839 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
3840 errors::InvalidArgument("beta1_power is not a scalar: ",
3841 beta1_power.shape().DebugString()));
3842 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
3844 lr.shape().DebugString()));
3845 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
3846 errors::InvalidArgument("beta1 is not a scalar: ",
3847 beta1.shape().DebugString()));
3848 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
3849 errors::InvalidArgument("beta2 is not a scalar: ",
3850 beta2.shape().DebugString()));
3851 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
3852 errors::InvalidArgument("epsilon is not a scalar: ",
3853 epsilon.shape().DebugString()));
3854
3855 const Tensor& grad = ctx->input(8);
3856 OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
3857 errors::InvalidArgument("var and m do not have the same shape",
3858 var.shape().DebugString(), " ",
3859 m.shape().DebugString()));
3860 OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
3861 errors::InvalidArgument("var and v do not have the same shape",
3862 var.shape().DebugString(), " ",
3863 v.shape().DebugString()));
3864 OP_REQUIRES(
3865 ctx, var.shape().IsSameSize(grad.shape()),
3866 errors::InvalidArgument("var and grad do not have the same shape",
3867 var.shape().DebugString(), " ",
3868 grad.shape().DebugString()));
3869
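    // AdaMax replaces Adam's second-moment estimate with an infinity-norm
    // accumulator; roughly (exact form in the ApplyAdaMax functor):
    //   m   = beta1 * m + (1 - beta1) * grad
    //   v   = max(beta2 * v, |grad|)
    //   var -= lr / (1 - beta1_power) * m / (v + epsilon)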
3870 const Device& device = ctx->template eigen_device<Device>();
3871 functor::ApplyAdaMax<Device, T>()(
3872 device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
3873 beta1_power.scalar<T>(), lr.scalar<T>(), beta1.scalar<T>(),
3874 beta2.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
3875
3876 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
3877 }
3878
3879 private:
3880 bool use_exclusive_lock_;
3881};
3882
3883#define REGISTER_KERNELS(D, T) \
3884 REGISTER_KERNEL_BUILDER( \
3885 Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \
3886 ApplyAdaMaxOp<D##Device, T>); \
3887 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \
3888 .HostMemory("var") \
3889 .HostMemory("m") \
3890 .HostMemory("v") \
3891 .Device(DEVICE_##D) \
3892 .TypeConstraint<T>("T"), \
3893 ApplyAdaMaxOp<D##Device, T>);
3894#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
3895
3896TF_CALL_half(REGISTER_CPU_KERNELS);
3897TF_CALL_float(REGISTER_CPU_KERNELS);
3898TF_CALL_double(REGISTER_CPU_KERNELS);
3899
3900#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3901// Forward declarations of the functor specializations for GPU.
3902namespace functor {
3903#define DECLARE_GPU_SPEC(T) \
3904 template <> \
3905 void ApplyAdaMax<GPUDevice, T>::operator()( \
3906 const GPUDevice& d, typename TTypes<T>::Flat var, \
3907 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
3908 typename TTypes<T>::ConstScalar beta1_power, \
3909 typename TTypes<T>::ConstScalar lr, \
3910 typename TTypes<T>::ConstScalar beta1, \
3911 typename TTypes<T>::ConstScalar beta2, \
3912 typename TTypes<T>::ConstScalar epsilon, \
3913 typename TTypes<T>::ConstFlat grad); \
3914 extern template struct ApplyAdaMax<GPUDevice, T>;
3915DECLARE_GPU_SPEC(Eigen::half);
3916DECLARE_GPU_SPEC(float);
3917DECLARE_GPU_SPEC(double);
3918#undef DECLARE_GPU_SPEC
3919} // namespace functor
3920
3921REGISTER_KERNELS(GPU, Eigen::half);
3922REGISTER_KERNELS(GPU, float);
3923REGISTER_KERNELS(GPU, double);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
3925#undef REGISTER_CPU_KERNELS
3926#undef REGISTER_KERNELS
3927
3928template <typename Device, typename T>
3929class ApplyRMSPropOp : public OpKernel {
3930 public:
3931 explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
3932 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
3933 }
3934
3935 void Compute(OpKernelContext* ctx) override {
3936 const bool sparse = false;
3937 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
3938 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
3939
3940 Tensor var;
3941 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3942 ctx, 0, use_exclusive_lock_, sparse, &var));
3943 Tensor ms;
3944 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3945 ctx, 1, use_exclusive_lock_, sparse, &ms));
3946 Tensor mom;
3947 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
3948 ctx, 2, use_exclusive_lock_, sparse, &mom));
3949
3950 OP_REQUIRES(
3951 ctx, var.IsInitialized(),
3952 errors::FailedPrecondition(
3953 "Attempting to use uninitialized variables: ", requested_input(0)));
3954 OP_REQUIRES(
3955 ctx, ms.IsInitialized(),
3956 errors::FailedPrecondition(
3957 "Attempting to use uninitialized variables: ", requested_input(1)));
3958 OP_REQUIRES(
3959 ctx, mom.IsInitialized(),
3960 errors::FailedPrecondition(
3961 "Attempting to use uninitialized variables: ", requested_input(2)));
3962
3963 const Tensor& lr = ctx->input(3);
3964 const Tensor& rho = ctx->input(4);
3965 const Tensor& momentum = ctx->input(5);
3966 const Tensor& epsilon = ctx->input(6);
3967 const Tensor& grad = ctx->input(7);
3968
3969 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
3971 lr.shape().DebugString()));
3972 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
3973 errors::InvalidArgument("rho is not a scalar: ",
3974 rho.shape().DebugString()));
3975 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
3976 errors::InvalidArgument("momentum is not a scalar: ",
3977 momentum.shape().DebugString()));
3978 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
3979 errors::InvalidArgument("epsilon is not a scalar: ",
3980 epsilon.shape().DebugString()));
3981
3982 OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
3983 errors::InvalidArgument("var and ms do not have the same shape",
3984 var.shape().DebugString(), " ",
3985 ms.shape().DebugString()));
3986
3987 OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
3988 errors::InvalidArgument(
3989 "var and mom do not have the same shape",
3990 var.shape().DebugString(), " ", mom.shape().DebugString()));
3991
3992 OP_REQUIRES(
3993 ctx, var.shape().IsSameSize(grad.shape()),
3994 errors::InvalidArgument("var and grad do not have the same shape",
3995 var.shape().DebugString(), " ",
3996 grad.shape().DebugString()));
3997
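    // The functor applies the same element-wise update that the sparse RMSProp
    // kernel below spells out explicitly:
    //   ms  = rho * ms + (1 - rho) * grad^2
    //   mom = momentum * mom + lr * grad / sqrt(ms + epsilon)
    //   var -= mom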
3998 const Device& device = ctx->template eigen_device<Device>();
3999 functor::ApplyRMSProp<Device, T>()(device, var.flat<T>(), ms.flat<T>(),
4000 mom.flat<T>(), lr.scalar<T>(),
4001 rho.scalar<T>(), momentum.scalar<T>(),
4002 epsilon.scalar<T>(), grad.flat<T>());
4003
4004 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
4005 }
4006
4007 private:
4008 bool use_exclusive_lock_;
4009};
4010
4011template <typename Device, typename T>
4012class ApplyCenteredRMSPropOp : public OpKernel {
4013 public:
4014 explicit ApplyCenteredRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
4015 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
4016 }
4017
4018 void Compute(OpKernelContext* ctx) override {
4019 const bool sparse = false;
4020 auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
4021 ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3});
4022
4023 Tensor var;
4024 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
4025 ctx, 0, use_exclusive_lock_, sparse, &var));
4026 Tensor mg;
4027 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
4028 ctx, 1, use_exclusive_lock_, sparse, &mg));
4029 Tensor ms;
4030 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
4031 ctx, 2, use_exclusive_lock_, sparse, &ms));
4032 Tensor mom;
4033 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
4034 ctx, 3, use_exclusive_lock_, sparse, &mom));
4035
4036 OP_REQUIRES(
4037 ctx, var.IsInitialized(),
4038 errors::FailedPrecondition(
4039 "Attempting to use uninitialized variables: ", requested_input(0)));
4040 OP_REQUIRES(
4041 ctx, mg.IsInitialized(),
4042 errors::FailedPrecondition(
4043 "Attempting to use uninitialized variables: ", requested_input(1)));
4044 OP_REQUIRES(
4045 ctx, ms.IsInitialized(),
4046 errors::FailedPrecondition(
4047 "Attempting to use uninitialized variables: ", requested_input(2)));
4048 OP_REQUIRES(
4049 ctx, mom.IsInitialized(),
4050 errors::FailedPrecondition(
4051 "Attempting to use uninitialized variables: ", requested_input(3)));
4052
4053 const Tensor& lr = ctx->input(4);
4054 const Tensor& rho = ctx->input(5);
4055 const Tensor& momentum = ctx->input(6);
4056 const Tensor& epsilon = ctx->input(7);
4057 const Tensor& grad = ctx->input(8);
4058
4059 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
4061 lr.shape().DebugString()));
4062 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
4063 errors::InvalidArgument("rho is not a scalar: ",
4064 rho.shape().DebugString()));
4065 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
4066 errors::InvalidArgument("momentum is not a scalar: ",
4067 momentum.shape().DebugString()));
4068 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
4069 errors::InvalidArgument("epsilon is not a scalar: ",
4070 epsilon.shape().DebugString()));
4071
4072 OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()),
4073 errors::InvalidArgument("var and mg do not have the same shape",
4074 var.shape().DebugString(), " ",
                                        mg.shape().DebugString()));
4076
4077 OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
4078 errors::InvalidArgument("var and ms do not have the same shape",
4079 var.shape().DebugString(), " ",
4080 ms.shape().DebugString()));
4081
4082 OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
4083 errors::InvalidArgument(
4084 "var and mom do not have the same shape",
4085 var.shape().DebugString(), " ", mom.shape().DebugString()));
4086
4087 OP_REQUIRES(
4088 ctx, var.shape().IsSameSize(grad.shape()),
4089 errors::InvalidArgument("var and grad do not have the same shape",
4090 var.shape().DebugString(), " ",
4091 grad.shape().DebugString()));
4092
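    // Centered RMSProp additionally tracks a running mean of the gradients
    // (mg) and subtracts its square inside the denominator (the sparse kernel
    // below spells out the same update):
    //   mg  = rho * mg + (1 - rho) * grad
    //   ms  = rho * ms + (1 - rho) * grad^2
    //   mom = momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
    //   var -= mom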
4093 const Device& device = ctx->template eigen_device<Device>();
4094 functor::ApplyCenteredRMSProp<Device, T>()(
4095 device, var.flat<T>(), mg.flat<T>(), ms.flat<T>(), mom.flat<T>(),
4096 lr.scalar<T>(), rho.scalar<T>(), momentum.scalar<T>(),
4097 epsilon.scalar<T>(), grad.flat<T>());
4098 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
4099 }
4100
4101 private:
4102 bool use_exclusive_lock_;
4103};
4104
4105#define REGISTER_KERNELS(D, T) \
4106 REGISTER_KERNEL_BUILDER( \
4107 Name("ApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
4108 ApplyRMSPropOp<D##Device, T>); \
4109 REGISTER_KERNEL_BUILDER( \
4110 Name("ApplyCenteredRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
4111 ApplyCenteredRMSPropOp<D##Device, T>); \
4112 REGISTER_KERNEL_BUILDER(Name("ResourceApplyRMSProp") \
4113 .Device(DEVICE_##D) \
4114 .HostMemory("var") \
4115 .HostMemory("ms") \
4116 .HostMemory("mom") \
4117 .TypeConstraint<T>("T"), \
4118 ApplyRMSPropOp<D##Device, T>); \
4119 REGISTER_KERNEL_BUILDER(Name("ResourceApplyCenteredRMSProp") \
4120 .Device(DEVICE_##D) \
4121 .HostMemory("var") \
4122 .HostMemory("mg") \
4123 .HostMemory("ms") \
4124 .HostMemory("mom") \
4125 .TypeConstraint<T>("T"), \
4126 ApplyCenteredRMSPropOp<D##Device, T>);
4127#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
4128
4129TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS);
4130TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
4131
4132#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
4133// Forward declarations of the functor specializations for GPU.
4134namespace functor {
4135#define DECLARE_GPU_SPEC(T) \
4136 template <> \
4137 void ApplyRMSProp<GPUDevice, T>::operator()( \
4138 const GPUDevice& d, typename TTypes<T>::Flat var, \
4139 typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, \
4140 typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \
4141 typename TTypes<T>::ConstScalar momentum, \
4142 typename TTypes<T>::ConstScalar epsilon, \
4143 typename TTypes<T>::ConstFlat grad); \
4144 extern template struct ApplyRMSProp<GPUDevice, T>; \
4145 template <> \
4146 void ApplyCenteredRMSProp<GPUDevice, T>::operator()( \
4147 const GPUDevice& d, typename TTypes<T>::Flat var, \
4148 typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms, \
4149 typename TTypes<T>::Flat mom, typename TTypes<T>::ConstScalar lr, \
4150 typename TTypes<T>::ConstScalar rho, \
4151 typename TTypes<T>::ConstScalar momentum, \
4152 typename TTypes<T>::ConstScalar epsilon, \
4153 typename TTypes<T>::ConstFlat grad); \
4154 extern template struct ApplyCenteredRMSProp<GPUDevice, T>;
4155DECLARE_GPU_SPEC(Eigen::half);
4156DECLARE_GPU_SPEC(float);
4157DECLARE_GPU_SPEC(double);
4158DECLARE_GPU_SPEC(complex64);
4159DECLARE_GPU_SPEC(complex128);
4160#undef DECLARE_GPU_SPEC
4161} // namespace functor
4162
4163REGISTER_KERNELS(GPU, Eigen::half);
4164REGISTER_KERNELS(GPU, float);
4165REGISTER_KERNELS(GPU, double);
4166REGISTER_KERNELS(GPU, complex64);
4167REGISTER_KERNELS(GPU, complex128);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
4169#undef REGISTER_CPU_KERNELS
4170#undef REGISTER_KERNELS
4171
// Note: this op runs on the CPU only.
4173template <typename T, typename Tindex>
4174class SparseApplyRMSPropOp : public OpKernel {
4175 public:
4176 explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
4177 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
4178 }
4179
4180 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
4181 const bool sparse = true;
4182 auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
4183 ctx, use_exclusive_lock_, sparse, {0, 1, 2});
4184
4185 Tensor var;
4186 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
4187 ctx, 0, use_exclusive_lock_, sparse, &var));
4188 Tensor ms;
4189 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
4190 ctx, 1, use_exclusive_lock_, sparse, &ms));
4191 Tensor mom;
4192 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
4193 ctx, 2, use_exclusive_lock_, sparse, &mom));
4194
4195 OP_REQUIRES(
4196 ctx, var.IsInitialized(),
4197 errors::FailedPrecondition(
4198 "Attempting to use uninitialized variables: ", requested_input(0)));
4199 OP_REQUIRES(
4200 ctx, ms.IsInitialized(),
4201 errors::FailedPrecondition(
4202 "Attempting to use uninitialized variables: ", requested_input(1)));
4203 OP_REQUIRES(
4204 ctx, mom.IsInitialized(),
4205 errors::FailedPrecondition(
4206 "Attempting to use uninitialized variables: ", requested_input(2)));
4207
4208 const Tensor& lr = ctx->input(3);
4209 const Tensor& rho = ctx->input(4);
4210 const Tensor& momentum = ctx->input(5);
4211 const Tensor& epsilon = ctx->input(6);
4212 const Tensor& grad = ctx->input(7);
4213 const Tensor& indices = ctx->input(8);
4214
4215 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
4216 errors::InvalidArgument("lr is not a scalar: ",
4217 lr.shape().DebugString()));
4218 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
4219 errors::InvalidArgument("rho is not a scalar: ",
4220 rho.shape().DebugString()));
4221 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
4222 errors::InvalidArgument("momentum is not a scalar: ",
4223 momentum.shape().DebugString()));
4224 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
4225 errors::InvalidArgument("epsilon is not a scalar: ",
4226 epsilon.shape().DebugString()));
4227
4228 OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
4229 errors::InvalidArgument("var and ms do not have the same shape",
4230 var.shape().DebugString(), " ",
4231 ms.shape().DebugString()));
4232
4233 OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
4234 errors::InvalidArgument(
4235 "var and mom do not have the same shape",
4236 var.shape().DebugString(), " ", mom.shape().DebugString()));
4237
4238 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
4239 errors::InvalidArgument("var must be at least 1 dimensional"));
4240
4241 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
4242 errors::InvalidArgument("indices must be one-dimensional"));
4243
4244 for (int d = 1; d < var.dims(); d++) {
4245 OP_REQUIRES(
4246 ctx, var.dim_size(d) == grad.dim_size(d),
4247 errors::InvalidArgument("var and grad must match in dimension ", d));
4248 }
4249 const Tindex N = indices.dim_size(0);
4250 OP_REQUIRES(
4251 ctx, grad.dim_size(0) == N,
4252 errors::InvalidArgument(
4253 "grad must be the same size as indices in the first dimension."));
4254
4255 if (N > 0) {
4256 const Tindex first_dim_size = var.dim_size(0);
4257 // Validate all the indices are in range
4258 auto indices_vec = indices.vec<Tindex>();
4259 for (Tindex i = 0; i < N; i++) {
4260 const Tindex index = indices_vec(i);
4261 OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
4262 errors::InvalidArgument(
4263 strings::StrCat("Index ", index, " at offset ", i,
4264 " in indices is out of range")));
4265 }
4266
4267 auto var_flat = var.flat_outer_dims<T>();
4268 auto ms_flat = ms.flat_outer_dims<T>();
4269 auto mom_flat = mom.flat_outer_dims<T>();
4270 auto grad_flat = grad.flat_outer_dims<T>();
4271 const T lr_scalar = lr.scalar<T>()();
4272 const T rho_scalar = rho.scalar<T>()();
4273 const T epsilon_scalar = epsilon.scalar<T>()();
4274 const T momentum_scalar = momentum.scalar<T>()();
4275
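      // Per-row RMSProp update:
      //   ms  = rho * ms + (1 - rho) * grad^2
      //   mom = momentum * mom + lr * grad / sqrt(ms + epsilon)
      //   var -= mom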
4276 for (Tindex i = 0; i < N; i++) {
4277 const Tindex index = indices_vec(i);
4278
4279 auto ms_ = ms_flat.template chip<0>(index);
4280 auto mom_ = mom_flat.template chip<0>(index);
4281 auto grad_ = grad_flat.template chip<0>(i);
4282
4283 ms_ = ms_ * ms_.constant(rho_scalar) +
4284 grad_.square() * grad_.constant(T(1) - rho_scalar);
4285 mom_ = mom_ * mom_.constant(momentum_scalar) +
4286 (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
4287 ms_.constant(lr_scalar) * grad_;
4288
4289 auto v = var_flat.template chip<0>(index);
4290 v -= mom_;
4291 }
4292 }
4293
4294 MaybeForwardRefInputToRefOutput(ctx, 0, 0);
4295 }
4296
4297 private:
4298 bool use_exclusive_lock_;
4299};
4300
// Note: this op runs on the CPU only.
4302template <typename T, typename Tindex>
4303class SparseApplyCenteredRMSPropOp : public OpKernel {
4304 public:
4305 explicit SparseApplyCenteredRMSPropOp(OpKernelConstruction* ctx)
4306 : OpKernel(ctx) {
4307 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
4308 }
4309
4310 void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS {
4311 const bool sparse = true;
4312 auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>(
4313 ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3});
4314
4315 Tensor var;
4316 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
4317 ctx, 0, use_exclusive_lock_, sparse, &var));
4318 Tensor mg;
4319 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
4320 ctx, 1, use_exclusive_lock_, sparse, &mg));
4321 Tensor ms;
4322 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
4323 ctx, 2, use_exclusive_lock_, sparse, &ms));
4324 Tensor mom;
4325 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
4326 ctx, 3, use_exclusive_lock_, sparse, &mom));
4327
4328 OP_REQUIRES(
4329 ctx, var.IsInitialized(),
4330 errors::FailedPrecondition(
4331 "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, mg.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, ms.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(2)));
4336 OP_REQUIRES(
4337 ctx, mom.IsInitialized(),
4338 errors::FailedPrecondition(
4339 "Attempting to use uninitialized variables: ", requested_input(3)));
4340
4341 const Tensor& lr = ctx->input(4);
4342 const Tensor& rho = ctx->input(5);
4343 const Tensor& momentum = ctx->input(6);
4344 const Tensor& epsilon = ctx->input(7);
4345 const Tensor& grad = ctx->input(8);
4346 const Tensor& indices = ctx->input(9);
4347
4348 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
4349 errors::InvalidArgument("lr is not a scalar: ",
4350 lr.shape().DebugString()));
4351 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
4352 errors::InvalidArgument("rho is not a scalar: ",
4353 rho.shape().DebugString()));
4354 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
4355 errors::InvalidArgument("momentum is not a scalar: ",
4356 momentum.shape().DebugString()));
4357 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
4358 errors::InvalidArgument("epsilon is not a scalar: ",
4359 epsilon.shape().DebugString()));
4360
4361 OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()),
4362 errors::InvalidArgument("var and mg do not have the same shape",
4363 var.shape().DebugString(), " ",
4364 mg.shape().DebugString()));
4365
4366 OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
4367 errors::InvalidArgument("var and ms do not have the same shape",
4368 var.shape().DebugString(), " ",
4369 ms.shape().DebugString()));
4370
4371 OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
4372 errors::InvalidArgument(
4373 "var and mom do not have the same shape",
4374 var.shape().DebugString(), " ", mom.shape().DebugString()));
4375
4376 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
4377 errors::InvalidArgument("var must be at least 1 dimensional"));
4378
4379 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
4380 errors::InvalidArgument("indices must be one-dimensional"));
4381
4382 for (int d = 1; d < var.dims(); d++) {
4383 OP_REQUIRES(
4384 ctx, var.dim_size(d) == grad.dim_size(d),
4385 errors::InvalidArgument("var and grad must match in dimension ", d));
4386 }
4387 const Tindex N = indices.dim_size(0);
4388 OP_REQUIRES(
4389 ctx, grad.dim_size(0) == N,
4390 errors::InvalidArgument(
4391 "grad must be the same size as indices in the first dimension."));
4392
4393 if (N > 0) {
4394 const Tindex first_dim_size = var.dim_size(0);
4395 // Validate all the indices are in range
4396 auto indices_vec = indices.vec<Tindex>();
4397 for (Tindex i = 0; i < N; i++) {
4398 const Tindex index = indices_vec(i);
4399 OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
4400 errors::InvalidArgument(
4401 strings::StrCat("Index ", index, " at offset ", i,
4402 " in indices is out of range")));
4403 }
4404
4405 auto var_flat = var.flat_outer_dims<T>();
4406 auto ms_flat = ms.flat_outer_dims<T>();
4407 auto mg_flat = mg.flat_outer_dims<T>();
4408 auto mom_flat = mom.flat_outer_dims<T>();
4409 auto grad_flat = grad.flat_outer_dims<T>();
4410 const T lr_scalar = lr.scalar<T>()();
4411 const T rho_scalar = rho.scalar<T>()();
4412 const T epsilon_scalar = epsilon.scalar<T>()();
4413 const T momentum_scalar = momentum.scalar<T>()();
4414
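      // For each selected row this performs the centered RMSProp update:
      //   ms  <- rho * ms + (1 - rho) * grad^2
      //   mg  <- rho * mg + (1 - rho) * grad
      //   mom <- momentum * mom + lr * grad / sqrt(ms + epsilon - mg^2)
      //   var <- var - mom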
      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);

        auto ms_ = ms_flat.template chip<0>(index);
        auto mom_ = mom_flat.template chip<0>(index);
        auto grad_ = grad_flat.template chip<0>(i);

        ms_ = ms_ * ms_.constant(rho_scalar) +
              grad_.square() * grad_.constant(T(1) - rho_scalar);

        auto mg_ = mg_flat.template chip<0>(index);
        mg_ = mg_ * mg_.constant(rho_scalar) +
              grad_ * grad_.constant(T(1) - rho_scalar);
        auto denom_ = ms_ + ms_.constant(epsilon_scalar) - mg_.square();
        mom_ = mom_ * mom_.constant(momentum_scalar) +
               denom_.rsqrt() * ms_.constant(lr_scalar) * grad_;
        auto v = var_flat.template chip<0>(index);
        v -= mom_;
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

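// The ref-variable ("SparseApply*") and resource-variable
// ("ResourceSparseApply*") ops share the same kernel classes;
// GetInputTensorFromVariable resolves either input kind to a dense tensor.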
#define REGISTER_KERNELS(T, Tindices)                                 \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                  \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyRMSPropOp<T, Tindices>);         \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp")          \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyCenteredRMSPropOp<T, Tindices>); \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyRMSProp")          \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyRMSPropOp<T, Tindices>);         \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyCenteredRMSProp")  \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyCenteredRMSPropOp<T, Tindices>);

REGISTER_KERNELS(Eigen::half, int32);
REGISTER_KERNELS(Eigen::half, int64_t);
REGISTER_KERNELS(float, int32);
REGISTER_KERNELS(float, int64_t);
REGISTER_KERNELS(double, int32);
REGISTER_KERNELS(double, int64_t);
REGISTER_KERNELS(complex64, int32);
REGISTER_KERNELS(complex64, int64_t);
REGISTER_KERNELS(complex128, int32);
REGISTER_KERNELS(complex128, int64_t);

#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyAddSignOp : public OpKernel {
 public:
  explicit ApplyAddSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    const bool sparse = false;
    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
        ctx, use_exclusive_lock_, sparse, {0, 1});

    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, sparse, &var));
    Tensor m;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, sparse, &m));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, m.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& alpha = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha.shape().DebugString()));
    const Tensor& sign_decay = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay.shape()),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay.shape().DebugString()));
    const Tensor& beta = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta.shape().DebugString()));
    const Tensor& grad = ctx->input(6);
    OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        m.shape().DebugString()));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));

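    // The functor applies the AddSign update (see the ApplyAddSign op
    // definition):
    //   m   <- beta * m + (1 - beta) * grad
    //   var <- var - lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad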
    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyAddSign<Device, T>()(
        device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(), alpha.scalar<T>(),
        sign_decay.scalar<T>(), beta.scalar<T>(), grad.flat<T>());
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(D, T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("ApplyAddSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      ApplyAddSignOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyAddSign")                \
                              .Device(DEVICE_##D)                     \
                              .HostMemory("var")                      \
                              .HostMemory("m")                        \
                              .TypeConstraint<T>("T"),                \
                          ApplyAddSignOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void ApplyAddSign<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T>::Flat var,               \
      typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr, \
      typename TTypes<T>::ConstScalar alpha,                          \
      typename TTypes<T>::ConstScalar sign_decay,                     \
      typename TTypes<T>::ConstScalar beta,                           \
      typename TTypes<T>::ConstFlat grad);                            \
  extern template struct ApplyAddSign<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor

REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyPowerSignOp : public OpKernel {
 public:
  explicit ApplyPowerSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    const bool sparse = false;
    auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
        ctx, use_exclusive_lock_, sparse, {0, 1});

    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, sparse, &var));
    Tensor m;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, sparse, &m));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, m.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& logbase = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase.shape()),
                errors::InvalidArgument("logbase is not a scalar: ",
                                        logbase.shape().DebugString()));
    const Tensor& sign_decay = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay.shape()),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay.shape().DebugString()));
    const Tensor& beta = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta.shape().DebugString()));
    const Tensor& grad = ctx->input(6);
    OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        m.shape().DebugString()));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));

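    // The functor applies the PowerSign update (see the ApplyPowerSign op
    // definition):
    //   m   <- beta * m + (1 - beta) * grad
    //   var <- var - lr * exp(logbase * sign_decay * sign(grad) * sign(m)) *
    //          grad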
    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyPowerSign<Device, T>()(
        device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(), logbase.scalar<T>(),
        sign_decay.scalar<T>(), beta.scalar<T>(), grad.flat<T>());
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(D, T)                                          \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("ApplyPowerSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      ApplyPowerSignOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyPowerSign")                \
                              .Device(DEVICE_##D)                       \
                              .HostMemory("var")                        \
                              .HostMemory("m")                          \
                              .TypeConstraint<T>("T"),                  \
                          ApplyPowerSignOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void ApplyPowerSign<GPUDevice, T>::operator()(                      \
      const GPUDevice& d, typename TTypes<T>::Flat var,               \
      typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr, \
      typename TTypes<T>::ConstScalar logbase,                        \
      typename TTypes<T>::ConstScalar sign_decay,                     \
      typename TTypes<T>::ConstScalar beta,                           \
      typename TTypes<T>::ConstFlat grad);                            \
  extern template struct ApplyPowerSign<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor

REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

} // namespace tensorflow