1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | #include "tensorflow/core/kernels/training_ops.h" |
18 | |
19 | #include <algorithm> // NOLINT |
20 | |
21 | #include "tensorflow/core/framework/bounds_check.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/register_types.h" |
24 | #include "tensorflow/core/kernels/training_op_helpers.h" |
25 | #include "tensorflow/core/kernels/variable_ops.h" |
26 | #include "tensorflow/core/lib/core/errors.h" |
27 | #include "tensorflow/core/platform/bfloat16.h" |
28 | #include "tensorflow/core/util/util.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | using CPUDevice = Eigen::ThreadPoolDevice; |
33 | using GPUDevice = Eigen::GpuDevice; |
34 | using Index = Eigen::Index; |
35 | |
36 | namespace { |
37 | template <class T> |
38 | inline T sgn(const T x) { |
39 | T zero(0); |
40 | T one(1); |
41 | return (x == zero ? zero : (x < zero ? -one : one)); |
42 | } |
43 | } // namespace |
44 | |
45 | namespace functor { |
46 | template <typename T> |
47 | struct ApplyGradientDescent<CPUDevice, T> { |
48 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
49 | typename TTypes<T>::ConstScalar lr, |
50 | typename TTypes<T>::ConstFlat grad) { |
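    // Plain gradient descent step: var <- var - lr * grad.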
51 | var.device(d) -= grad * lr(); |
52 | } |
53 | }; |
54 | |
55 | template <typename T> |
56 | struct ApplyAdadelta<CPUDevice, T> { |
57 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
58 | typename TTypes<T>::Flat accum, |
59 | typename TTypes<T>::Flat accum_update, |
60 | typename TTypes<T>::ConstScalar lr, |
61 | typename TTypes<T>::ConstScalar rho, |
62 | typename TTypes<T>::ConstScalar epsilon, |
63 | typename TTypes<T>::ConstFlat grad) { |
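    // Adadelta update:
    //   accum        <- rho * accum + (1 - rho) * grad^2
    //   update        = sqrt(accum_update + eps) / sqrt(accum + eps) * grad
    //   var          <- var - lr * update
    //   accum_update <- rho * accum_update + (1 - rho) * update^2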
64 | accum.device(d) = |
65 | accum * rho() + grad.square() * (static_cast<T>(1) - rho()); |
66 | const auto update = |
67 | (accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad; |
68 | var.device(d) -= update * lr(); |
69 | accum_update.device(d) = |
70 | accum_update * rho() + update.square() * (static_cast<T>(1) - rho()); |
71 | } |
72 | }; |
73 | |
74 | template <typename T> |
75 | struct ApplyProximalGradientDescent<CPUDevice, T> { |
76 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
77 | typename TTypes<T>::ConstScalar lr, |
78 | typename TTypes<T>::ConstScalar l1, |
79 | typename TTypes<T>::ConstScalar l2, |
80 | typename TTypes<T>::ConstFlat grad) { |
    // Note that this is the FOBOS update; for details, please refer to:
82 | // http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf |
83 | // TODO(xbing): merge the logic for ProximalGradientDescent and |
84 | // ProximalAdagrad. |
85 | auto prox_var = var; |
86 | // compute v = w - lr * grad. |
87 | prox_var.device(d) -= grad * lr(); |
88 | if (l1() > 0) { |
      // compute sign(v) * max(|v| - lr * l1, 0) / (1 + lr * l2)
90 | var.device(d) = |
91 | prox_var.sign() * |
92 | (prox_var.abs() - var.constant(lr() * l1())).cwiseMax(T(0.0)) / |
93 | (var.constant(1.0) + var.constant(l2() * lr())); |
94 | } else { |
95 | var.device(d) = |
96 | prox_var / (var.constant(1.0) + var.constant(l2() * lr())); |
97 | } |
98 | } |
99 | }; |
100 | |
101 | template <typename T> |
102 | struct ApplyAdagradDA<CPUDevice, T> { |
103 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
104 | typename TTypes<T>::Flat gradient_accum, |
105 | typename TTypes<T>::Flat gradient_squared_accum, |
106 | typename TTypes<T>::ConstScalar lr, int64_t global_step, |
107 | typename TTypes<T>::ConstScalar l1, |
108 | typename TTypes<T>::ConstScalar l2, |
109 | typename TTypes<T>::ConstFlat grad) { |
    // Accumulate the gradient and the squared gradient.
111 | gradient_accum.device(d) += grad; |
112 | gradient_squared_accum.device(d) += grad.square(); |
113 | |
    // AdagradDA update:
    // Let g be the gradient accumulator, gg the gradient squared accumulator,
    // T the global step, lr the learning rate, and k the initial value of the
    // gradient squared accumulator.
    // w = \dfrac{sign(-g)*lr*(|g| - l1*T)_{+}}{l2*T*lr + \sqrt{k+gg}}
119 | if (l1() > 0) { |
120 | var.device(d) = |
121 | lr() * var.constant(-1.0) * gradient_accum.sign() * |
122 | (gradient_accum.abs() - |
123 | var.constant(static_cast<float>(global_step)) * var.constant(l1())) |
124 | .cwiseMax(T(0.0)) / |
125 | (var.constant(l2()) * |
126 | var.constant(static_cast<float>(global_step) * lr()) + |
127 | gradient_squared_accum.sqrt()); |
128 | } else { |
129 | var.device(d) = |
130 | lr() * gradient_accum * var.constant(-1.0) / |
131 | (var.constant(l2()) * |
132 | var.constant(static_cast<float>(global_step) * lr()) + |
133 | gradient_squared_accum.sqrt()); |
134 | } |
135 | } |
136 | }; |
137 | |
138 | template <typename T> |
139 | struct ApplyAdagrad<CPUDevice, T> { |
140 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
141 | typename TTypes<T>::Flat accum, |
142 | typename TTypes<T>::ConstScalar lr, |
143 | typename TTypes<T>::ConstFlat grad, bool update_slots) { |
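    // Adagrad update:
    //   accum <- accum + grad^2   (only when update_slots is true)
    //   var   <- var - lr * grad / sqrt(accum)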
144 | if (update_slots) { |
145 | accum.device(d) += grad.square(); |
146 | } |
147 | var.device(d) -= grad * lr() * accum.rsqrt(); |
148 | } |
149 | }; |
150 | |
151 | template <typename T> |
152 | struct ApplyAdagradV2<CPUDevice, T> { |
153 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
154 | typename TTypes<T>::Flat accum, |
155 | typename TTypes<T>::ConstScalar lr, |
156 | typename TTypes<T>::ConstScalar epsilon, |
157 | typename TTypes<T>::ConstFlat grad, bool update_slots) { |
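    // AdagradV2 update (Adagrad with an explicit epsilon in the denominator):
    //   accum <- accum + grad^2   (only when update_slots is true)
    //   var   <- var - lr * grad / (sqrt(accum) + epsilon)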
158 | if (update_slots) { |
159 | accum.device(d) += grad.square(); |
160 | } |
161 | var.device(d) -= grad * lr() / (accum.sqrt() + epsilon()); |
162 | } |
163 | }; |
164 | |
165 | template <typename T, typename Tindex, bool has_epsilon> |
166 | struct SparseApplyAdagrad<CPUDevice, T, Tindex, has_epsilon> { |
167 | Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var, |
168 | typename TTypes<T>::Matrix accum, |
169 | typename TTypes<T>::ConstScalar lr, |
170 | typename TTypes<T>::ConstScalar epsilon, |
171 | typename TTypes<T>::ConstMatrix grad, |
172 | typename TTypes<Tindex>::ConstVec indices, |
173 | int64_t inner_dim, bool update_slots) { |
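    // Two code paths: for inner_dim > 1 each index selects a row (chip) of
    // var/accum, for inner_dim == 1 the update operates on scalars directly.
    // Both paths are sharded with parallelFor using the cost estimate below.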
174 | const Tindex N = static_cast<Tindex>(indices.dimension(0)); |
175 | if (N == 0) return OkStatus(); |
176 | const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0)); |
177 | const T lr_scalar = lr(); |
178 | const int in_bytes = inner_dim * sizeof(T) * 3; |
179 | const int out_bytes = inner_dim * sizeof(T) * 2; |
180 | const int cycles = inner_dim * (Eigen::TensorOpCost::AddCost<T>() * 2 + |
181 | Eigen::TensorOpCost::MulCost<T>() * 2); |
182 | const Eigen::TensorOpCost cost(in_bytes, out_bytes, cycles); |
183 | |
184 | if (inner_dim > 1) { |
185 | for (Tindex i = 0; i < N; ++i) { |
186 | const Tindex index = internal::SubtleMustCopy(indices(i)); |
187 | if (!FastBoundsCheck(index, first_dim_size)) { |
188 | return errors::InvalidArgument( |
              strings::StrCat("Index ", index, " at offset ", i,
                              " in indices is out of range"));
191 | } |
192 | } |
193 | |
194 | const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void { |
195 | for (Tindex i = start_idx; i < end_idx; ++i) { |
196 | const Tindex index = internal::SubtleMustCopy(indices(i)); |
197 | auto a = accum.template chip<0>(index); |
198 | auto g = grad.template chip<0>(i); |
199 | auto v = var.template chip<0>(index); |
200 | if (update_slots) { |
201 | a += g.square(); |
202 | } |
203 | if (has_epsilon) { |
204 | v -= g.constant(lr_scalar) * g / (a.sqrt() + a.constant(epsilon())); |
205 | } else { |
206 | v -= g.constant(lr_scalar) * g * a.rsqrt(); |
207 | } |
208 | } |
209 | }; |
210 | |
211 | d.parallelFor(N, cost, shard); |
212 | } else { |
213 | for (Tindex i = 0; i < N; ++i) { |
214 | const Tindex index = internal::SubtleMustCopy(indices(i)); |
215 | if (!FastBoundsCheck(index, first_dim_size)) { |
216 | return errors::InvalidArgument( |
              strings::StrCat("Index ", index, " at offset ", i,
                              " in indices is out of range"));
219 | } |
220 | } |
221 | |
222 | const auto shard = [&](Tindex start_idx, Tindex end_idx) -> void { |
223 | for (Tindex i = start_idx; i < end_idx; ++i) { |
224 | const Tindex index = internal::SubtleMustCopy(indices(i)); |
225 | T& a = accum(index); |
226 | const T& g = grad(i); |
227 | if (update_slots) { |
228 | a += g * g; |
229 | } |
230 | if (has_epsilon) { |
231 | var(index) -= lr_scalar * g / (Eigen::numext::sqrt(a) + epsilon()); |
232 | } else { |
233 | var(index) -= lr_scalar * g / Eigen::numext::sqrt(a); |
234 | } |
235 | } |
236 | }; |
237 | |
238 | d.parallelFor(N, cost, shard); |
239 | } |
240 | |
241 | return OkStatus(); |
242 | } |
243 | }; |
244 | |
245 | template <typename T> |
246 | struct ApplyProximalAdagrad<CPUDevice, T> { |
247 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
248 | typename TTypes<T>::Flat accum, |
249 | typename TTypes<T>::ConstScalar lr, |
250 | typename TTypes<T>::ConstScalar l1, |
251 | typename TTypes<T>::ConstScalar l2, |
252 | typename TTypes<T>::ConstFlat grad) { |
    // FOBOS update (per the paper cited above) with the Adagrad learning rate.
254 | accum.device(d) += grad.square(); |
255 | // Adagrad learning rate. |
256 | auto learning_rate = accum.constant(lr()) * accum.rsqrt(); |
257 | auto prox_var = var; |
258 | // compute v = w - lr * grad. |
259 | prox_var.device(d) -= grad * learning_rate; |
260 | if (l1() > 0) { |
      // compute sign(v) * max(|v| - lr * l1, 0) / (1 + lr * l2)
262 | var.device(d) = prox_var.sign() * |
263 | (prox_var.abs() - learning_rate * prox_var.constant(l1())) |
264 | .cwiseMax(T(0.0)) / |
265 | (var.constant(1.0) + var.constant(l2()) * learning_rate); |
266 | } else { |
267 | var.device(d) = |
268 | prox_var / (var.constant(1.0) + var.constant(l2()) * learning_rate); |
269 | } |
270 | } |
271 | }; |
272 | |
273 | template <typename T, typename Tindex> |
274 | struct SparseApplyProximalAdagrad<CPUDevice, T, Tindex> { |
275 | Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var, |
276 | typename TTypes<T>::Matrix accum, |
277 | typename TTypes<T>::ConstScalar lr, |
278 | typename TTypes<T>::ConstScalar l1, |
279 | typename TTypes<T>::ConstScalar l2, |
280 | typename TTypes<T>::ConstMatrix grad, |
281 | typename TTypes<Tindex>::ConstVec indices, |
282 | int64_t inner_dim) { |
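    // A row (chip) path for inner_dim > 1 and a scalar path for
    // inner_dim == 1; each applies the Adagrad accumulator update followed by
    // the proximal (FOBOS) shrinkage.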
283 | const Tindex N = static_cast<Tindex>(indices.dimension(0)); |
284 | if (N == 0) return OkStatus(); |
285 | const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0)); |
286 | const T lr_scalar = lr(); |
287 | const T l1_scalar = l1(); |
288 | const T l2_scalar = l2(); |
289 | if (inner_dim > 1) { |
290 | for (Tindex i = 0; i < N; i++) { |
291 | const Tindex index = internal::SubtleMustCopy(indices(i)); |
292 | if (!FastBoundsCheck(index, first_dim_size)) { |
293 | return errors::InvalidArgument( |
              strings::StrCat("Index ", index, " at offset ", i,
                              " in indices is out of range"));
296 | } |
297 | auto a = accum.template chip<0>(index); |
298 | auto g = grad.template chip<0>(i); |
299 | auto v = var.template chip<0>(index); |
300 | a += g.square(); |
301 | // compute learning_rate for current step. |
302 | auto learning_rate = a.constant(lr_scalar) * a.rsqrt(); |
303 | auto prox_v = v; |
304 | // v = w - g * learning_rate. |
305 | prox_v -= g * learning_rate; |
306 | if (l1_scalar > 0) { |
          // compute sign(v) * max(|v| - lr * l1, 0) / (1 + l2 * lr)
308 | v = prox_v.sign() * |
309 | (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar)) |
310 | .cwiseMax(static_cast<T>(0.0)) / |
311 | (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); |
312 | } else { |
313 | v = prox_v / |
314 | (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); |
315 | } |
316 | } |
317 | } else { |
318 | for (Tindex i = 0; i < N; i++) { |
319 | const Tindex index = internal::SubtleMustCopy(indices(i)); |
320 | if (!FastBoundsCheck(index, first_dim_size)) { |
321 | return errors::InvalidArgument( |
              strings::StrCat("Index ", index, " at offset ", i,
                              " in indices is out of range"));
324 | } |
325 | T& a = accum(index); |
326 | const T& g = grad(i); |
327 | a += g * g; |
328 | auto learning_rate = lr_scalar / std::sqrt(a); |
329 | auto prox_v = var(index); |
330 | prox_v -= learning_rate * g; |
331 | if (l1_scalar > 0) { |
332 | var(index) = sgn(prox_v) * |
333 | std::max(std::abs(prox_v) - learning_rate * l1_scalar, |
334 | static_cast<T>(0.0)) / |
335 | (1.0 + l2_scalar * learning_rate); |
336 | } else { |
337 | var(index) = prox_v / (1.0 + l2_scalar * learning_rate); |
338 | } |
339 | } |
340 | } |
341 | return OkStatus(); |
342 | } |
343 | }; |
344 | |
345 | template <typename T> |
346 | struct ApplyFtrlV2<CPUDevice, T> { |
347 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
348 | typename TTypes<T>::Flat accum, |
349 | typename TTypes<T>::Flat linear, |
350 | typename TTypes<T>::ConstFlat grad, |
351 | typename TTypes<T>::ConstScalar lr, |
352 | typename TTypes<T>::ConstScalar l1, |
353 | typename TTypes<T>::ConstScalar l2, |
354 | typename TTypes<T>::ConstScalar l2_shrinkage, |
355 | typename TTypes<T>::ConstScalar lr_power) { |
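    // FTRL update with L2 shrinkage:
    //   grad_with_shrinkage = grad + 2 * l2_shrinkage * var
    //   new_accum = accum + grad^2
    //   linear   += grad_with_shrinkage
    //               - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
    //   var       = (l1 * sign(linear) - linear) /
    //               (new_accum^(-lr_power) / lr + 2 * l2) if |linear| > l1,
    //               else 0
    //   accum    += grad^2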
356 | auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var; |
357 | auto new_accum = accum + grad * grad; |
358 | // special case for which lr_power=-0.5. |
359 | if (lr_power() == static_cast<T>(-0.5)) { |
360 | linear.device(d) += |
361 | grad_with_shrinkage - (new_accum.sqrt() - accum.sqrt()) / lr() * var; |
362 | } else { |
363 | linear.device(d) += |
364 | grad_with_shrinkage - |
365 | (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var; |
366 | } |
367 | auto x = (linear.constant(l1()) * linear.sign() - linear); |
368 | if (lr_power() == static_cast<T>(-0.5)) { |
369 | auto y = new_accum.sqrt() / new_accum.constant(lr()) + |
370 | linear.constant(static_cast<T>(2) * l2()); |
371 | auto pre_shrink = x / y; |
372 | var.device(d) = (linear.abs() > linear.constant(l1())) |
373 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
374 | |
375 | } else { |
376 | auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) + |
377 | linear.constant(static_cast<T>(2) * l2()); |
378 | auto pre_shrink = x / y; |
379 | var.device(d) = (linear.abs() > linear.constant(l1())) |
380 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
381 | } |
382 | accum.device(d) += grad * grad; |
383 | } |
384 | }; |
385 | |
386 | template <typename T> |
387 | struct ApplyFtrlV2MultiplyLinearByLr<CPUDevice, T> { |
388 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
389 | typename TTypes<T>::Flat accum, |
390 | typename TTypes<T>::Flat linear, |
391 | typename TTypes<T>::ConstFlat grad, |
392 | typename TTypes<T>::ConstScalar lr, |
393 | typename TTypes<T>::ConstScalar l1, |
394 | typename TTypes<T>::ConstScalar l2, |
395 | typename TTypes<T>::ConstScalar l2_shrinkage, |
396 | typename TTypes<T>::ConstScalar lr_power) { |
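    // Same update as ApplyFtrlV2, except the linear accumulator is stored
    // pre-multiplied by lr: the gradient term is scaled by lr, the division
    // by lr disappears, and the l1/l2 thresholds are scaled by lr instead.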
397 | auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var; |
398 | auto new_accum = accum + grad * grad; |
399 | // special case for which lr_power=-0.5. |
400 | if (lr_power() == static_cast<T>(-0.5)) { |
401 | linear.device(d) += |
402 | grad_with_shrinkage * lr() - (new_accum.sqrt() - accum.sqrt()) * var; |
403 | } else { |
404 | linear.device(d) += |
405 | grad_with_shrinkage * lr() - |
406 | (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) * var; |
407 | } |
408 | auto x = (linear.constant(l1() * lr()) * linear.sign() - linear); |
409 | if (lr_power() == static_cast<T>(-0.5)) { |
410 | auto y = |
411 | new_accum.sqrt() + linear.constant(static_cast<T>(2) * l2() * lr()); |
412 | auto pre_shrink = x / y; |
413 | var.device(d) = (linear.abs() > linear.constant(l1() * lr())) |
414 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
415 | |
416 | } else { |
417 | auto y = new_accum.pow(-lr_power()) + |
418 | linear.constant(static_cast<T>(2) * l2() * lr()); |
419 | auto pre_shrink = x / y; |
420 | var.device(d) = (linear.abs() > linear.constant(l1() * lr())) |
421 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
422 | } |
423 | accum.device(d) += grad * grad; |
424 | } |
425 | }; |
426 | |
427 | template <typename T> |
428 | struct ApplyFtrl<CPUDevice, T> { |
429 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
430 | typename TTypes<T>::Flat accum, |
431 | typename TTypes<T>::Flat linear, |
432 | typename TTypes<T>::ConstFlat grad, |
433 | typename TTypes<T>::ConstScalar lr, |
434 | typename TTypes<T>::ConstScalar l1, |
435 | typename TTypes<T>::ConstScalar l2, |
436 | typename TTypes<T>::ConstScalar lr_power) { |
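    // FTRL update without L2 shrinkage; identical to ApplyFtrlV2 with
    // l2_shrinkage = 0, so the gradient is used directly in the linear term.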
437 | auto new_accum = accum + grad.square(); |
438 | // special case for which lr_power=-0.5. |
439 | if (lr_power() == static_cast<T>(-0.5)) { |
440 | linear.device(d) += grad - (new_accum.sqrt() - accum.sqrt()) / lr() * var; |
441 | } else { |
442 | linear.device(d) += |
443 | grad - |
444 | (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var; |
445 | } |
446 | auto x = (linear.constant(l1()) * linear.sign() - linear); |
447 | if (lr_power() == static_cast<T>(-0.5)) { |
448 | auto y = new_accum.sqrt() / new_accum.constant(lr()) + |
449 | linear.constant(static_cast<T>(2) * l2()); |
450 | auto pre_shrink = x / y; |
451 | var.device(d) = (linear.abs() > linear.constant(l1())) |
452 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
453 | |
454 | } else { |
455 | auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) + |
456 | linear.constant(static_cast<T>(2) * l2()); |
457 | auto pre_shrink = x / y; |
458 | var.device(d) = (linear.abs() > linear.constant(l1())) |
459 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
460 | } |
461 | accum.device(d) += grad.square(); |
462 | } |
463 | }; |
464 | |
465 | template <typename T> |
466 | struct ApplyFtrlMultiplyLinearByLr<CPUDevice, T> { |
467 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
468 | typename TTypes<T>::Flat accum, |
469 | typename TTypes<T>::Flat linear, |
470 | typename TTypes<T>::ConstFlat grad, |
471 | typename TTypes<T>::ConstScalar lr, |
472 | typename TTypes<T>::ConstScalar l1, |
473 | typename TTypes<T>::ConstScalar l2, |
474 | typename TTypes<T>::ConstScalar lr_power) { |
475 | auto new_accum = accum + grad.square(); |
476 | // special case for which lr_power=-0.5. |
477 | if (lr_power() == static_cast<T>(-0.5)) { |
478 | linear.device(d) += grad * lr() - (new_accum.sqrt() - accum.sqrt()) * var; |
479 | } else { |
480 | linear.device(d) += |
481 | grad * lr() - |
482 | (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) * var; |
483 | } |
484 | auto x = (linear.constant(l1()) * lr() * linear.sign() - linear); |
485 | if (lr_power() == static_cast<T>(-0.5)) { |
486 | auto y = |
487 | new_accum.sqrt() + linear.constant(static_cast<T>(2) * l2() * lr()); |
488 | auto pre_shrink = x / y; |
489 | var.device(d) = (linear.abs() > linear.constant(l1() * lr())) |
490 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
491 | |
492 | } else { |
493 | auto y = new_accum.pow(-lr_power()) + |
494 | linear.constant(static_cast<T>(2) * l2() * lr()); |
495 | auto pre_shrink = x / y; |
496 | var.device(d) = (linear.abs() > linear.constant(l1() * lr())) |
497 | .select(pre_shrink, var.constant(static_cast<T>(0))); |
498 | } |
499 | accum.device(d) += grad.square(); |
500 | } |
501 | }; |
502 | |
503 | namespace { |
504 | |
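// Closed-form FTRL solution for a single element: clamp linear to the
// [-l1, l1] band (scaled by lr when multiply_linear_by_lr) and divide the
// difference by the quadratic term accum^(-lr_power) / lr + 2 * l2 (or its
// lr-scaled variant).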
505 | template <typename T> |
506 | inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1, |
507 | const T& l2, const T& lr_power, |
508 | const bool multiply_linear_by_lr) { |
509 | T quadratic; |
510 | if (multiply_linear_by_lr) { |
511 | if (lr_power == static_cast<T>(-0.5)) { |
512 | quadratic = Eigen::numext::sqrt(accum) + static_cast<T>(2) * l2 * lr; |
513 | } else { |
514 | quadratic = |
515 | Eigen::numext::pow(accum, -lr_power) + static_cast<T>(2) * l2 * lr; |
516 | } |
517 | auto l1_reg_adjust = std::max(std::min(linear, l1 * lr), -l1 * lr); |
518 | return (l1_reg_adjust - linear) / quadratic; |
519 | } else { |
520 | if (lr_power == static_cast<T>(-0.5)) { |
521 | quadratic = Eigen::numext::sqrt(accum) / lr + static_cast<T>(2) * l2; |
522 | } else { |
523 | quadratic = |
524 | Eigen::numext::pow(accum, -lr_power) / lr + static_cast<T>(2) * l2; |
525 | } |
526 | auto l1_reg_adjust = std::max(std::min(linear, l1), -l1); |
527 | return (l1_reg_adjust - linear) / quadratic; |
528 | } |
529 | } |
530 | |
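// Applies one dense FTRL step to a row (Eigen chip): accumulates grad^2 into
// a temporary new_accum, updates the linear term from the (possibly
// shrinkage-adjusted) gradient, solves the soft-thresholded closed form for
// var, and finally folds grad^2 into accum.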
template <typename T, typename GradTy, typename GradMaybeWithShrinkageTy,
          typename AccumTy, typename LinearTy, typename VarTy>
void ComputeFtrl(GradTy grad,
                 GradMaybeWithShrinkageTy grad_maybe_with_shrinkage,
535 | AccumTy accum, LinearTy linear, VarTy var, T l1_scalar, |
536 | T l2_scalar, bool multiply_linear_by_lr, T lr_power_scalar, |
537 | T lr_scalar) { |
538 | auto new_accum = accum + grad.square(); |
539 | if (multiply_linear_by_lr) { |
540 | if (lr_power_scalar == static_cast<T>(-0.5)) { |
541 | linear += grad_maybe_with_shrinkage * lr_scalar - |
542 | (new_accum.sqrt() - accum.sqrt()) * var; |
543 | } else { |
544 | linear += |
545 | grad_maybe_with_shrinkage * lr_scalar - |
546 | (new_accum.pow(-lr_power_scalar) - accum.pow(-lr_power_scalar)) * var; |
547 | } |
548 | } else { |
549 | if (lr_power_scalar == static_cast<T>(-0.5)) { |
550 | linear += grad_maybe_with_shrinkage - |
551 | (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; |
552 | } else { |
553 | linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - |
554 | accum.pow(-lr_power_scalar)) / |
555 | lr_scalar * var; |
556 | } |
557 | } |
558 | auto l1_reg_adjust = |
559 | (multiply_linear_by_lr ? linear.cwiseMin(l1_scalar * lr_scalar) |
560 | .cwiseMax(-l1_scalar * lr_scalar) |
561 | : linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar)); |
562 | auto x = l1_reg_adjust - linear; |
563 | if (multiply_linear_by_lr) { |
564 | if (lr_power_scalar == static_cast<T>(-0.5)) { |
565 | auto y = new_accum.sqrt() + |
566 | linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar); |
567 | var = x / y; |
568 | } else { |
569 | auto y = new_accum.pow(-lr_power_scalar) + |
570 | linear.constant(static_cast<T>(2) * l2_scalar * lr_scalar); |
571 | var = x / y; |
572 | } |
573 | } else { |
574 | if (lr_power_scalar == static_cast<T>(-0.5)) { |
575 | auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) + |
576 | linear.constant(static_cast<T>(2) * l2_scalar); |
577 | var = x / y; |
578 | } else { |
579 | auto y = new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + |
580 | linear.constant(static_cast<T>(2) * l2_scalar); |
581 | var = x / y; |
582 | } |
583 | } |
584 | accum += grad.square(); |
585 | } |
586 | } // namespace |
587 | |
588 | template <typename T, typename Tindex, bool has_l2_shrinkage> |
589 | struct SparseApplyFtrl<CPUDevice, T, Tindex, has_l2_shrinkage> { |
590 | Status operator()(const CPUDevice& d, typename TTypes<T>::Matrix var_flat, |
591 | typename TTypes<T>::Matrix accum_flat, |
592 | typename TTypes<T>::Matrix linear_flat, |
593 | typename TTypes<T>::ConstScalar lr, |
594 | typename TTypes<T>::ConstScalar l1, |
595 | typename TTypes<T>::ConstScalar l2, |
596 | typename TTypes<T>::ConstScalar l2_shrinkage, |
597 | typename TTypes<T>::ConstScalar lr_power, |
598 | typename TTypes<T>::ConstMatrix grad_flat, |
599 | typename TTypes<Tindex>::ConstVec indices_vec, |
600 | int64_t inner_dim, bool multiply_linear_by_lr) { |
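    // For inner_dim > 1 each index selects a row (chip) and the update is
    // delegated to ComputeFtrl; for inner_dim == 1 the scalar accumulator and
    // linear terms are updated in place and var is set via FtrlCompute.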
601 | const Tindex N = static_cast<Tindex>(indices_vec.dimension(0)); |
602 | if (N > 0) { |
603 | T lr_scalar = lr(); |
604 | T l1_scalar = l1(); |
605 | T l2_scalar = l2(); |
606 | T l2_shrinkage_scalar; |
607 | if (has_l2_shrinkage) { |
608 | l2_shrinkage_scalar = l2_shrinkage(); |
609 | } |
610 | T lr_power_scalar = lr_power(); |
611 | if (inner_dim > 1) { |
612 | const Tindex first_dim_size = |
613 | static_cast<Tindex>(var_flat.dimension(0)); |
614 | |
615 | for (Tindex i = 0; i < N; i++) { |
616 | const Tindex index = internal::SubtleMustCopy(indices_vec(i)); |
617 | if (!FastBoundsCheck(index, first_dim_size)) { |
618 | return errors::InvalidArgument( |
                strings::StrCat("Index ", index, " at offset ", i,
                                " in indices is out of range"));
621 | } |
622 | auto accum = accum_flat.template chip<0>(index); |
623 | auto linear = linear_flat.template chip<0>(index); |
624 | auto grad = grad_flat.template chip<0>(i); |
625 | auto var = var_flat.template chip<0>(index); |
626 | |
627 | if (has_l2_shrinkage) { |
628 | auto grad_with_shrinkage = |
629 | grad + static_cast<T>(2) * l2_shrinkage_scalar * var; |
630 | ComputeFtrl(/*grad=*/grad, |
631 | /*grad_maybe_with_shrinkage=*/grad_with_shrinkage, |
632 | /*accum=*/accum, /*linear=*/linear, /*var=*/var, |
633 | /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar, |
634 | /*multiply_linear_by_lr=*/multiply_linear_by_lr, |
635 | /*lr_power_scalar=*/lr_power_scalar, |
636 | /*lr_scalar=*/lr_scalar); |
637 | } else { |
638 | ComputeFtrl(/*grad=*/grad, /*grad_maybe_with_shrinkage=*/grad, |
639 | /*accum=*/accum, /*linear=*/linear, /*var=*/var, |
640 | /*l1_scalar=*/l1_scalar, /*l2_scalar=*/l2_scalar, |
641 | /*multiply_linear_by_lr=*/multiply_linear_by_lr, |
642 | /*lr_power_scalar=*/lr_power_scalar, |
643 | /*lr_scalar=*/lr_scalar); |
644 | } |
645 | } |
646 | } else { |
647 | const Tindex first_dim_size = accum_flat.size(); |
648 | |
649 | for (Tindex i = 0; i < N; i++) { |
650 | const Tindex index = internal::SubtleMustCopy(indices_vec(i)); |
651 | if (!FastBoundsCheck(index, first_dim_size)) { |
652 | return errors::InvalidArgument( |
                strings::StrCat("Index ", index, " at offset ", i,
                                " in indices is out of range"));
655 | } |
656 | T& a = accum_flat(index); |
657 | T& l = linear_flat(index); |
658 | T& v = var_flat(index); |
659 | T g; |
660 | if (has_l2_shrinkage) { |
661 | g = grad_flat(i) + |
662 | (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(index)); |
663 | } else { |
664 | g = grad_flat(i); |
665 | } |
666 | |
667 | T updated_a = a + grad_flat(i) * grad_flat(i); |
668 | using Eigen::numext::pow; |
669 | T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar); |
670 | if (!multiply_linear_by_lr) { |
671 | sigma /= lr_scalar; |
672 | } |
673 | T updated_l = (multiply_linear_by_lr ? l + g * lr_scalar - sigma * v |
674 | : l + g - sigma * v); |
675 | v = FtrlCompute(updated_a, updated_l, lr_scalar, l1_scalar, l2_scalar, |
676 | lr_power_scalar, multiply_linear_by_lr); |
677 | a = updated_a; |
678 | l = updated_l; |
679 | } |
680 | } |
681 | } |
682 | return OkStatus(); |
683 | } |
684 | }; |
685 | |
686 | template <typename T> |
687 | struct ApplyMomentum<CPUDevice, T> { |
688 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
689 | typename TTypes<T>::Flat accum, |
690 | typename TTypes<T>::ConstScalar lr, |
691 | typename TTypes<T>::ConstFlat grad, |
692 | typename TTypes<T>::ConstScalar momentum, bool use_nesterov) { |
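    // Momentum update:
    //   accum <- momentum * accum + grad
    //   var   <- var - lr * accum                       (standard)
    //   var   <- var - lr * (grad + momentum * accum)   (Nesterov)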
693 | accum.device(d) = accum * momentum() + grad; |
694 | if (use_nesterov) { |
695 | var.device(d) -= grad * lr() + accum * momentum() * lr(); |
696 | } else { |
697 | var.device(d) -= accum * lr(); |
698 | } |
699 | } |
700 | }; |
701 | |
702 | template <typename T> |
703 | struct ApplyKerasMomentum<CPUDevice, T> { |
704 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
705 | typename TTypes<T>::Flat accum, |
706 | typename TTypes<T>::ConstScalar lr, |
707 | typename TTypes<T>::ConstFlat grad, |
708 | typename TTypes<T>::ConstScalar momentum, bool use_nesterov) { |
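    // Keras-style momentum (the velocity already absorbs the learning rate):
    //   accum <- momentum * accum - lr * grad
    //   var   <- var + accum                            (standard)
    //   var   <- var + momentum * accum - lr * grad     (Nesterov)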
709 | accum.device(d) = accum * momentum() - grad * lr(); |
710 | if (use_nesterov) { |
711 | var.device(d) += (accum * momentum() - grad * lr()); |
712 | } else { |
713 | var.device(d) += accum; |
714 | } |
715 | } |
716 | }; |
717 | |
718 | template <typename T, typename Tindex> |
719 | struct SparseApplyKerasMomentum<CPUDevice, T, Tindex> { |
720 | Tindex operator()(const CPUDevice& d, typename TTypes<T>::Matrix var, |
721 | typename TTypes<T>::Matrix accum, |
722 | typename TTypes<T>::ConstScalar lr, |
723 | typename TTypes<T>::ConstMatrix grad, |
724 | typename TTypes<Tindex>::ConstFlat indices, |
725 | typename TTypes<T>::ConstScalar momentum, |
726 | bool use_nesterov) { |
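    // Returns -1 on success, or the offset i of the first out-of-range index
    // so the caller can report it.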
727 | const Tindex N = static_cast<Tindex>(indices.size()); |
728 | const Tindex first_dim_size = static_cast<Tindex>(var.dimension(0)); |
729 | for (Tindex i = 0; i < N; i++) { |
730 | const Tindex index = internal::SubtleMustCopy(indices(i)); |
731 | if (!FastBoundsCheck(index, first_dim_size)) return i; |
732 | auto a = accum.template chip<0>(index); |
733 | auto g = grad.template chip<0>(i); |
734 | auto v = var.template chip<0>(index); |
735 | a = a * a.constant(momentum()) - g * g.constant(lr()); |
736 | if (use_nesterov) { |
737 | v += a * a.constant(momentum()) - g * g.constant(lr()); |
738 | } else { |
739 | v += a; |
740 | } |
741 | } |
742 | return -1; |
743 | } |
744 | }; |
745 | |
746 | template <typename T, typename Tindex> |
747 | struct SparseApplyAdadelta<CPUDevice, T, Tindex> { |
748 | void operator()(const CPUDevice& d, typename TTypes<T>::Matrix var, |
749 | typename TTypes<T>::Matrix accum, |
750 | typename TTypes<T>::Matrix accum_update, |
751 | typename TTypes<T>::ConstScalar lr, |
752 | typename TTypes<T>::ConstScalar rho, |
753 | typename TTypes<T>::ConstScalar epsilon, |
754 | typename TTypes<T>::ConstMatrix grad, |
755 | typename TTypes<Tindex>::ConstFlat indices) { |
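    // Per-row Adadelta update; indices are expected to be in range (the op
    // validates them on CPU before invoking this functor).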
756 | const Tindex N = static_cast<Tindex>(indices.size()); |
757 | for (Tindex i = 0; i < N; i++) { |
758 | const Tindex index = indices(i); |
759 | auto a = accum.template chip<0>(index); |
760 | auto a_update = accum_update.template chip<0>(index); |
761 | auto g = grad.template chip<0>(i); |
762 | |
763 | a = a * a.constant(rho()) + g.square() * g.constant(T(1) - rho()); |
764 | const auto update = (a_update + a_update.constant(epsilon())).sqrt() * |
765 | (a + a.constant(epsilon())).rsqrt() * g; |
766 | auto v = var.template chip<0>(index); |
767 | v -= update * update.constant(lr()); |
768 | a_update = a_update * a_update.constant(rho()) + |
769 | update.square() * update.constant(static_cast<T>(1) - rho()); |
770 | } |
771 | } |
772 | }; |
773 | |
774 | template <typename Device, typename T> |
775 | struct ApplyAdamNonCuda { |
776 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
777 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, |
778 | typename TTypes<T>::ConstScalar beta1_power, |
779 | typename TTypes<T>::ConstScalar beta2_power, |
780 | typename TTypes<T>::ConstScalar lr, |
781 | typename TTypes<T>::ConstScalar beta1, |
782 | typename TTypes<T>::ConstScalar beta2, |
783 | typename TTypes<T>::ConstScalar epsilon, |
784 | typename TTypes<T>::ConstFlat grad, bool use_nesterov) { |
    // Get the length of the params and check whether it can be vectorized by
    // packet size.
786 | Index length = var.size(); |
787 | Index packet_size = Eigen::internal::packet_traits<T>::size; |
788 | if (length % packet_size == 0) { |
789 | length = length / packet_size; |
790 | } else { |
791 | packet_size = 1; |
792 | } |
793 | |
794 | T* var_ptr = var.data(); |
795 | T* m_ptr = m.data(); |
796 | T* v_ptr = v.data(); |
797 | const T* g_ptr = grad.data(); |
798 | const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) / |
799 | (T(1) - beta1_power()); |
800 | // beta1 == μ |
801 | // beta2 == ν |
802 | // v == n |
803 | // var == θ |
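    // Per-element update:
    //   m   <- m + (1 - beta1) * (g - m)
    //   v   <- v + (1 - beta2) * (g^2 - v)
    //   var <- var - alpha * m / (sqrt(v) + epsilon)
    // where alpha = lr * sqrt(1 - beta2_power) / (1 - beta1_power); the
    // Nesterov variant uses (1 - beta1) * g + beta1 * m in place of m.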
804 | |
805 | auto shard = [var_ptr, m_ptr, v_ptr, g_ptr, alpha, beta1, beta2, epsilon, |
806 | use_nesterov, packet_size](int begin, int end) { |
807 | int t_size = (end - begin) * packet_size; |
808 | begin = begin * packet_size; |
809 | auto var = typename TTypes<T>::UnalignedTensor(var_ptr + begin, t_size); |
810 | auto m = typename TTypes<T>::UnalignedTensor(m_ptr + begin, t_size); |
811 | auto v = typename TTypes<T>::UnalignedTensor(v_ptr + begin, t_size); |
812 | auto g = typename TTypes<T>::UnalignedConstTensor(g_ptr + begin, t_size); |
813 | |
814 | if (use_nesterov) { |
815 | m += (g - m) * (T(1) - beta1()); |
816 | v += (g.square() - v) * (T(1) - beta2()); |
817 | var -= ((g * (T(1) - beta1()) + beta1() * m) * alpha) / |
818 | (v.sqrt() + epsilon()); |
819 | } else { |
820 | m += (g - m) * (T(1) - beta1()); |
821 | v += (g.square() - v) * (T(1) - beta2()); |
822 | var -= (m * alpha) / (v.sqrt() + epsilon()); |
823 | } |
824 | }; |
825 | |
826 | // Input data: var, v, m, grad. |
827 | // Output data: var, v, m. |
828 | const int input_bytes = length * packet_size * sizeof(T) * 4; |
829 | const int output_bytes = length * packet_size * sizeof(T) * 3; |
830 | const int compute_cycles = |
831 | // Consider Sub as Add |
832 | (Eigen::TensorOpCost::AddCost<int>() * 5 + |
833 | Eigen::TensorOpCost::MulCost<int>() * 2 + |
834 | Eigen::TensorOpCost::AddCost<T>() * 10 + |
835 | Eigen::TensorOpCost::MulCost<T>() * 6 + |
836 | Eigen::TensorOpCost::DivCost<T>()) * |
837 | length; |
838 | const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles); |
839 | |
    // The Eigen device would have to update 3 variables with 3 different
    // expressions, which is bad for cache locality on CPU. Use parallelFor
    // instead of "regular" tensor expressions to get better performance.
843 | d.parallelFor(length, cost, shard); |
844 | } |
845 | }; |
846 | |
847 | template <typename T> |
848 | struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {}; |
849 | |
850 | template <typename T> |
851 | struct ApplyAdamWithAmsgrad<CPUDevice, T> { |
852 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
853 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, |
854 | typename TTypes<T>::Flat vhat, |
855 | typename TTypes<T>::ConstScalar beta1_power, |
856 | typename TTypes<T>::ConstScalar beta2_power, |
857 | typename TTypes<T>::ConstScalar lr, |
858 | typename TTypes<T>::ConstScalar beta1, |
859 | typename TTypes<T>::ConstScalar beta2, |
860 | typename TTypes<T>::ConstScalar epsilon, |
861 | typename TTypes<T>::ConstFlat grad) { |
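    // AMSGrad variant of Adam: vhat keeps the element-wise running maximum of
    // v and replaces v in the denominator of the parameter update.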
862 | const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) / |
863 | (T(1) - beta1_power()); |
864 | |
865 | m.device(d) += (grad - m) * (T(1) - beta1()); |
866 | v.device(d) += (grad.square() - v) * (T(1) - beta2()); |
867 | vhat.device(d) = vhat.cwiseMax(v); |
868 | var.device(d) -= (m * alpha) / (vhat.sqrt() + epsilon()); |
869 | } |
870 | }; |
871 | |
872 | template <typename Device, typename T> |
873 | struct ApplyAdaMaxNonCuda { |
874 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
875 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, |
876 | typename TTypes<T>::ConstScalar beta1_power, |
877 | typename TTypes<T>::ConstScalar lr, |
878 | typename TTypes<T>::ConstScalar beta1, |
879 | typename TTypes<T>::ConstScalar beta2, |
880 | typename TTypes<T>::ConstScalar epsilon, |
881 | typename TTypes<T>::ConstFlat grad) { |
882 | m.device(d) += (grad - m) * (T(1) - beta1()); |
    // Here v is u in section 7.1 of the Adam paper (the AdaMax variant).
884 | v.device(d) = (beta2() * v).cwiseMax(grad.abs()); |
885 | // var is θ in section 7.1 |
886 | var.device(d) -= lr() / (T(1) - beta1_power()) * (m / (v + epsilon())); |
887 | } |
888 | }; |
889 | |
890 | template <typename T> |
891 | struct ApplyAdaMax<CPUDevice, T> : ApplyAdaMaxNonCuda<CPUDevice, T> {}; |
892 | |
893 | template <typename T> |
894 | struct ApplyRMSProp<CPUDevice, T> { |
895 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
896 | typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, |
897 | typename TTypes<T>::ConstScalar lr, |
898 | typename TTypes<T>::ConstScalar rho, |
899 | typename TTypes<T>::ConstScalar momentum, |
900 | typename TTypes<T>::ConstScalar epsilon, |
901 | typename TTypes<T>::ConstFlat grad) { |
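    // RMSProp update:
    //   ms  <- rho * ms + (1 - rho) * grad^2
    //   mom <- momentum * mom + lr * grad / sqrt(ms + epsilon)
    //   var <- var - mom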
902 | ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho()); |
903 | mom.device(d) = |
904 | mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt()); |
905 | var.device(d) -= mom; |
906 | } |
907 | }; |
908 | |
909 | template <typename T> |
910 | struct ApplyCenteredRMSProp<CPUDevice, T> { |
911 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
912 | typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms, |
913 | typename TTypes<T>::Flat mom, |
914 | typename TTypes<T>::ConstScalar lr, |
915 | typename TTypes<T>::ConstScalar rho, |
916 | typename TTypes<T>::ConstScalar momentum, |
917 | typename TTypes<T>::ConstScalar epsilon, |
918 | typename TTypes<T>::ConstFlat grad) { |
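    // Centered RMSProp: additionally track the mean gradient mg and use the
    // variance estimate ms - mg^2 in the denominator.
    //   ms  <- rho * ms + (1 - rho) * grad^2
    //   mg  <- rho * mg + (1 - rho) * grad
    //   mom <- momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
    //   var <- var - mom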
919 | ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho()); |
920 | mg.device(d) += (grad - mg) * (static_cast<T>(1) - rho()); |
921 | auto denom = (ms - mg.square()) + epsilon(); |
922 | mom.device(d) = mom * momentum() + (grad * lr()) / denom.sqrt(); |
923 | var.device(d) -= mom; |
924 | } |
925 | }; |
926 | |
927 | template <typename T> |
928 | struct ApplyAddSign<CPUDevice, T> { |
929 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
930 | typename TTypes<T>::Flat m, |
931 | typename TTypes<T>::ConstScalar lr, |
932 | typename TTypes<T>::ConstScalar alpha, |
933 | typename TTypes<T>::ConstScalar sign_decay, |
934 | typename TTypes<T>::ConstScalar beta, |
935 | typename TTypes<T>::ConstFlat grad) { |
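    // AddSign update:
    //   m   <- beta * m + (1 - beta) * grad
    //   var <- var - lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad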
936 | m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta()); |
937 | auto sign_gm = grad.sign() * m.sign(); |
938 | var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad; |
939 | } |
940 | }; |
941 | |
942 | template <typename T> |
943 | struct ApplyPowerSign<CPUDevice, T> { |
944 | void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, |
945 | typename TTypes<T>::Flat m, |
946 | typename TTypes<T>::ConstScalar lr, |
947 | typename TTypes<T>::ConstScalar logbase, |
948 | typename TTypes<T>::ConstScalar sign_decay, |
949 | typename TTypes<T>::ConstScalar beta, |
950 | typename TTypes<T>::ConstFlat grad) { |
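    // PowerSign update:
    //   m   <- beta * m + (1 - beta) * grad
    //   var <- var - lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad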
951 | m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta()); |
952 | auto sign_gm = grad.sign() * m.sign(); |
953 | auto grad_scale = (logbase() * sign_decay() * sign_gm).exp(); |
954 | var.device(d) -= lr() * grad_scale * grad; |
955 | } |
956 | }; |
957 | |
958 | } // namespace functor |
959 | |
960 | template <typename Device, typename T> |
961 | class ApplyGradientDescentOp : public OpKernel { |
962 | public: |
963 | explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
965 | } |
966 | |
967 | void Compute(OpKernelContext* ctx) override { |
968 | const bool sparse = false; |
969 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
970 | ctx, use_exclusive_lock_, sparse, {0}); |
971 | Tensor var; |
972 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
973 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
974 | |
975 | OP_REQUIRES( |
976 | ctx, var.IsInitialized(), |
977 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(0)));
979 | const Tensor& alpha = ctx->input(1); |
980 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()), |
                errors::InvalidArgument("alpha is not a scalar: ",
982 | alpha.shape().DebugString())); |
983 | const Tensor& delta = ctx->input(2); |
984 | OP_REQUIRES( |
985 | ctx, var.shape().IsSameSize(delta.shape()), |
        errors::InvalidArgument("var and delta do not have the same shape",
                                var.shape().DebugString(), " ",
988 | delta.shape().DebugString())); |
989 | |
990 | const Device& device = ctx->template eigen_device<Device>(); |
991 | functor::ApplyGradientDescent<Device, T>()( |
992 | device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>()); |
993 | |
994 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
995 | } |
996 | |
997 | private: |
998 | bool use_exclusive_lock_; |
999 | }; |
1000 | |
1001 | #define REGISTER_KERNELS(D, T) \ |
1002 | REGISTER_KERNEL_BUILDER( \ |
1003 | Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
1004 | ApplyGradientDescentOp<D##Device, T>); \ |
1005 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyGradientDescent") \ |
1006 | .Device(DEVICE_##D) \ |
1007 | .HostMemory("var") \ |
1008 | .TypeConstraint<T>("T"), \ |
1009 | ApplyGradientDescentOp<D##Device, T>); |
1010 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
1011 | |
1012 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
1013 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
1014 | |
1015 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1016 | // Forward declarations of the functor specializations for GPU. |
1017 | namespace functor { |
1018 | #define DECLARE_GPU_SPEC(T) \ |
1019 | template <> \ |
1020 | void ApplyGradientDescent<GPUDevice, T>::operator()( \ |
1021 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
1022 | typename TTypes<T>::ConstScalar alpha, \ |
1023 | typename TTypes<T>::ConstFlat delta); \ |
1024 | extern template struct ApplyGradientDescent<GPUDevice, T>; |
1025 | DECLARE_GPU_SPEC(Eigen::half); |
1026 | DECLARE_GPU_SPEC(float); |
1027 | DECLARE_GPU_SPEC(double); |
1028 | DECLARE_GPU_SPEC(complex64); |
1029 | DECLARE_GPU_SPEC(complex128); |
1030 | #undef DECLARE_GPU_SPEC |
1031 | } // namespace functor |
1032 | |
1033 | REGISTER_KERNELS(GPU, Eigen::half); |
1034 | REGISTER_KERNELS(GPU, float); |
1035 | REGISTER_KERNELS(GPU, double); |
1036 | REGISTER_KERNELS(GPU, complex64); |
1037 | REGISTER_KERNELS(GPU, complex128); |
1038 | #endif |
1039 | |
1040 | #undef REGISTER_CPU_KERNELS |
1041 | #undef REGISTER_KERNELS |
1042 | |
1043 | template <typename Device, typename T> |
1044 | class ApplyAdadeltaOp : public OpKernel { |
1045 | public: |
1046 | explicit ApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1048 | } |
1049 | |
1050 | void Compute(OpKernelContext* ctx) override { |
1051 | const bool sparse = false; |
1052 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
1053 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
1054 | DoValidate(ctx); |
1055 | if (!ctx->status().ok()) return; |
1056 | DoCompute(ctx); |
1057 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1058 | } |
1059 | |
1060 | private: |
1061 | bool use_exclusive_lock_; |
1062 | |
1063 | void DoValidate(OpKernelContext* ctx) { |
1064 | Tensor var; |
1065 | const bool sparse = false; |
1066 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1067 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1068 | Tensor accum; |
1069 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1070 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
1071 | Tensor accum_update; |
1072 | OP_REQUIRES_OK( |
1073 | ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_, |
1074 | sparse, &accum_update)); |
1075 | |
1076 | OP_REQUIRES( |
1077 | ctx, var.IsInitialized(), |
1078 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(0)));
1080 | OP_REQUIRES( |
1081 | ctx, accum.IsInitialized(), |
1082 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(1)));
1084 | OP_REQUIRES( |
1085 | ctx, accum_update.IsInitialized(), |
1086 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(2)));
1088 | |
1089 | const Tensor& lr = ctx->input(3); |
1090 | const Tensor& rho = ctx->input(4); |
1091 | const Tensor& epsilon = ctx->input(5); |
1092 | const Tensor& grad = ctx->input(6); |
1093 | |
1094 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
                errors::InvalidArgument("lr is not a scalar: ",
1096 | lr.shape().DebugString())); |
1097 | |
1098 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), |
                errors::InvalidArgument("rho is not a scalar: ",
1100 | rho.shape().DebugString())); |
1101 | |
1102 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
                errors::InvalidArgument("epsilon is not a scalar: ",
1104 | epsilon.shape().DebugString())); |
1105 | |
1106 | OP_REQUIRES( |
1107 | ctx, var.shape().IsSameSize(accum.shape()), |
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
1110 | accum.shape().DebugString())); |
1111 | OP_REQUIRES( |
1112 | ctx, var.shape().IsSameSize(grad.shape()), |
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
1115 | grad.shape().DebugString())); |
1116 | } |
1117 | |
1118 | void DoCompute(OpKernelContext* ctx) { |
1119 | const Device& device = ctx->template eigen_device<Device>(); |
1120 | Tensor var; |
1121 | const bool sparse = false; |
1122 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1123 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1124 | Tensor accum; |
1125 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1126 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
1127 | Tensor accum_update; |
1128 | OP_REQUIRES_OK( |
1129 | ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_, |
1130 | sparse, &accum_update)); |
1131 | |
1132 | const Tensor& lr = ctx->input(3); |
1133 | const Tensor& rho = ctx->input(4); |
1134 | const Tensor& epsilon = ctx->input(5); |
1135 | const Tensor& grad = ctx->input(6); |
1136 | |
1137 | functor::ApplyAdadelta<Device, T>()( |
1138 | device, var.flat<T>(), accum.flat<T>(), accum_update.flat<T>(), |
1139 | lr.scalar<T>(), rho.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>()); |
1140 | } |
1141 | }; |
1142 | |
1143 | #define REGISTER_KERNELS(D, T) \ |
1144 | REGISTER_KERNEL_BUILDER( \ |
1145 | Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
1146 | ApplyAdadeltaOp<D##Device, T>); \ |
1147 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdadelta") \ |
1148 | .Device(DEVICE_##D) \ |
1149 | .HostMemory("var") \ |
1150 | .HostMemory("accum") \ |
1151 | .HostMemory("accum_update") \ |
1152 | .TypeConstraint<T>("T"), \ |
1153 | ApplyAdadeltaOp<D##Device, T>); |
1154 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
1155 | |
1156 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
1157 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
1158 | |
1159 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1160 | // Forward declarations of the functor specializations for GPU. |
1161 | namespace functor { |
1162 | #define DECLARE_GPU_SPEC(T) \ |
1163 | template <> \ |
1164 | void ApplyAdadelta<GPUDevice, T>::operator()( \ |
1165 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
1166 | typename TTypes<T>::Flat accum, typename TTypes<T>::Flat accum_update, \ |
1167 | typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \ |
1168 | typename TTypes<T>::ConstScalar epsilon, \ |
1169 | typename TTypes<T>::ConstFlat grad); \ |
1170 | extern template struct ApplyAdadelta<GPUDevice, T>; |
1171 | DECLARE_GPU_SPEC(Eigen::half); |
1172 | DECLARE_GPU_SPEC(float); |
1173 | DECLARE_GPU_SPEC(double); |
1174 | DECLARE_GPU_SPEC(complex64); |
1175 | DECLARE_GPU_SPEC(complex128); |
1176 | #undef DECLARE_GPU_SPEC |
1177 | } // namespace functor |
1178 | |
1179 | REGISTER_KERNELS(GPU, Eigen::half); |
1180 | REGISTER_KERNELS(GPU, float); |
1181 | REGISTER_KERNELS(GPU, double); |
1182 | REGISTER_KERNELS(GPU, complex64); |
1183 | REGISTER_KERNELS(GPU, complex128); |
1184 | #endif |
1185 | #undef REGISTER_CPU_KERNELS |
1186 | #undef REGISTER_KERNELS |
1187 | |
1188 | template <typename T, typename Device, typename Tindex> |
1189 | class SparseApplyAdadeltaOp : public OpKernel { |
1190 | public: |
1191 | explicit SparseApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1193 | } |
1194 | |
1195 | void Compute(OpKernelContext* ctx) override { |
1196 | const bool sparse = true; |
1197 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
1198 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
1199 | DoCompute(ctx); |
1200 | } |
1201 | |
1202 | void DoCompute(OpKernelContext* ctx) { |
1203 | Tensor var; |
1204 | const bool sparse = true; |
1205 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1206 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1207 | Tensor accum_grad; |
1208 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1209 | ctx, 1, use_exclusive_lock_, sparse, &accum_grad)); |
1210 | Tensor accum_update; |
1211 | OP_REQUIRES_OK( |
1212 | ctx, GetInputTensorFromVariable<Device, T>(ctx, 2, use_exclusive_lock_, |
1213 | sparse, &accum_update)); |
1214 | OP_REQUIRES( |
1215 | ctx, var.IsInitialized(), |
1216 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(0)));
1218 | OP_REQUIRES( |
1219 | ctx, accum_grad.IsInitialized(), |
1220 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(1)));
1222 | OP_REQUIRES( |
1223 | ctx, accum_update.IsInitialized(), |
1224 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(2)));
1226 | OP_REQUIRES( |
1227 | ctx, var.shape().IsSameSize(accum_grad.shape()), |
        errors::InvalidArgument("var and accum_grad do not have the same shape",
                                var.shape().DebugString(), " ",
1230 | accum_grad.shape().DebugString())); |
1231 | OP_REQUIRES(ctx, var.shape().IsSameSize(accum_update.shape()), |
1232 | errors::InvalidArgument( |
                    "var and accum_update do not have the same shape",
                    var.shape().DebugString(), " ",
1235 | accum_update.shape().DebugString())); |
1236 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
                errors::InvalidArgument("var must be at least 1 dimensional"));
1238 | |
1239 | const Tensor& lr = ctx->input(3); |
1240 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
                errors::InvalidArgument("lr is not a scalar: ",
1242 | lr.shape().DebugString())); |
1243 | const Tensor& rho = ctx->input(4); |
1244 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), |
                errors::InvalidArgument("rho is not a scalar: ",
1246 | rho.shape().DebugString())); |
1247 | const Tensor& epsilon = ctx->input(5); |
1248 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
                errors::InvalidArgument("epsilon is not a scalar: ",
1250 | epsilon.shape().DebugString())); |
1251 | const Tensor& grad = ctx->input(6); |
1252 | const Tensor& indices = ctx->input(7); |
1253 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
                errors::InvalidArgument("indices must be one-dimensional"));
1255 | |
1256 | for (int d = 1; d < var.dims(); d++) { |
1257 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
1258 | errors::InvalidArgument(strings::StrCat( |
                      "var and grad must match in dimension ", d)));
1260 | } |
1261 | const Tindex N = indices.dim_size(0); |
1262 | OP_REQUIRES( |
1263 | ctx, grad.dim_size(0) == N, |
1264 | errors::InvalidArgument( |
            "grad must be the same size as indices in the first dimension."));
1266 | |
1267 | if (N > 0) { |
1268 | const Tindex first_dim_size = var.dim_size(0); |
1269 | // Validate all the indices are in range |
1270 | auto indices_vec = indices.vec<Tindex>(); |
1271 | for (Tindex i = 0; i < N; i++) { |
1272 | const Tindex index = indices_vec(i); |
1273 | OP_REQUIRES(ctx, |
1274 | (!std::is_same<Device, CPUDevice>::value || |
1275 | (index >= 0 && index < first_dim_size)), |
1276 | errors::InvalidArgument( |
                        strings::StrCat("Index ", index, " at offset ", i,
                                        " in indices is out of range")));
1279 | } |
1280 | |
1281 | const Device& device = ctx->template eigen_device<Device>(); |
1282 | functor::SparseApplyAdadelta<Device, T, Tindex>()( |
1283 | device, var.flat_outer_dims<T>(), accum_grad.flat_outer_dims<T>(), |
1284 | accum_update.flat_outer_dims<T>(), lr.scalar<T>(), rho.scalar<T>(), |
1285 | epsilon.scalar<T>(), grad.flat_outer_dims<T>(), indices_vec); |
1286 | } |
1287 | |
1288 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1289 | } |
1290 | |
1291 | private: |
1292 | bool use_exclusive_lock_; |
1293 | }; |
1294 | |
1295 | #define REGISTER_KERNELS(T, D, Tindices) \ |
1296 | REGISTER_KERNEL_BUILDER(Name("SparseApplyAdadelta") \ |
1297 | .Device(DEVICE_##D) \ |
1298 | .TypeConstraint<T>("T") \ |
1299 | .TypeConstraint<Tindices>("Tindices"), \ |
1300 | SparseApplyAdadeltaOp<T, D##Device, Tindices>); \ |
1301 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdadelta") \ |
1302 | .Device(DEVICE_##D) \ |
1303 | .TypeConstraint<T>("T") \ |
1304 | .TypeConstraint<Tindices>("Tindices"), \ |
1305 | SparseApplyAdadeltaOp<T, D##Device, Tindices>); |
1306 | #define REGISTER_CPU_KERNELS(T) \ |
1307 | REGISTER_KERNELS(T, CPU, int32); \ |
1308 | REGISTER_KERNELS(T, CPU, int64_t); |
1309 | |
1310 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
1311 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
1312 | |
1313 | #undef REGISTER_CPU_KERNELS |
1314 | |
1315 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1316 | // Forward declarations of the functor specializations for GPU. |
1317 | namespace functor { |
1318 | #define DECLARE_GPU_SPEC(T, Tindex) \ |
1319 | template <> \ |
1320 | void SparseApplyAdadelta<GPUDevice, T, Tindex>::operator()( \ |
1321 | const GPUDevice& d, typename TTypes<T>::Matrix var, \ |
1322 | typename TTypes<T>::Matrix accum, \ |
1323 | typename TTypes<T>::Matrix accum_update, \ |
1324 | typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \ |
1325 | typename TTypes<T>::ConstScalar epsilon, \ |
1326 | typename TTypes<T>::ConstMatrix grad, \ |
1327 | typename TTypes<Tindex>::ConstFlat indices); \ |
1328 | extern template struct SparseApplyAdadelta<GPUDevice, T, Tindex>; |
1329 | DECLARE_GPU_SPEC(Eigen::half, int32); |
1330 | DECLARE_GPU_SPEC(Eigen::half, int64_t); |
1331 | DECLARE_GPU_SPEC(float, int32); |
1332 | DECLARE_GPU_SPEC(float, int64_t); |
1333 | DECLARE_GPU_SPEC(double, int32); |
1334 | DECLARE_GPU_SPEC(double, int64_t); |
1335 | DECLARE_GPU_SPEC(complex64, int32); |
1336 | DECLARE_GPU_SPEC(complex64, int64_t); |
1337 | DECLARE_GPU_SPEC(complex128, int32); |
1338 | DECLARE_GPU_SPEC(complex128, int64_t); |
1339 | #undef DECLARE_GPU_SPEC |
1340 | } // namespace functor |
1341 | |
1342 | #define REGISTER_GPU_KERNELS(T) \ |
1343 | REGISTER_KERNELS(T, GPU, int32); \ |
1344 | REGISTER_KERNELS(T, GPU, int64_t); |
1345 | |
1346 | REGISTER_GPU_KERNELS(Eigen::half); |
1347 | REGISTER_GPU_KERNELS(float); |
1348 | REGISTER_GPU_KERNELS(double); |
1349 | REGISTER_GPU_KERNELS(complex64); |
1350 | REGISTER_GPU_KERNELS(complex128); |
1351 | #undef REGISTER_GPU_KERNELS |
1352 | #endif |
1353 | #undef REGISTER_KERNELS |
1354 | |
// Note: this op works on CPU only.
1356 | template <typename Device, typename T> |
1357 | class ApplyProximalGradientDescentOp : public OpKernel { |
1358 | public: |
1359 | explicit ApplyProximalGradientDescentOp(OpKernelConstruction* ctx) |
1360 | : OpKernel(ctx) { |
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
1362 | } |
1363 | |
1364 | void Compute(OpKernelContext* ctx) override { |
1365 | const bool sparse = false; |
1366 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
1367 | ctx, use_exclusive_lock_, sparse, {0}); |
1368 | Tensor var; |
1369 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1370 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1371 | |
1372 | OP_REQUIRES( |
1373 | ctx, var.IsInitialized(), |
1374 | errors::FailedPrecondition( |
            "Attempting to use uninitialized variables: ", requested_input(0)));
1376 | const Tensor& alpha = ctx->input(1); |
1377 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()), |
                errors::InvalidArgument("alpha is not a scalar: ",
1379 | alpha.shape().DebugString())); |
1380 | const Tensor& l1 = ctx->input(2); |
1381 | OP_REQUIRES( |
1382 | ctx, TensorShapeUtils::IsScalar(l1.shape()), |
1383 | errors::InvalidArgument("l1 regularization strength is not a scalar: " , |
1384 | l1.shape().DebugString())); |
1385 | const Tensor& l2 = ctx->input(3); |
1386 | OP_REQUIRES( |
1387 | ctx, TensorShapeUtils::IsScalar(l2.shape()), |
1388 | errors::InvalidArgument("l2 regularization strength is not a scalar: " , |
1389 | l2.shape().DebugString())); |
1390 | |
1391 | const Tensor& delta = ctx->input(4); |
1392 | OP_REQUIRES( |
1393 | ctx, var.shape().IsSameSize(delta.shape()), |
1394 | errors::InvalidArgument("var and delta do not have the same shape" , |
1395 | var.shape().DebugString(), " " , |
1396 | delta.shape().DebugString())); |
1397 | |
1398 | const Device& device = ctx->template eigen_device<Device>(); |
1399 | functor::ApplyProximalGradientDescent<Device, T>()( |
1400 | device, var.flat<T>(), alpha.scalar<T>(), l1.scalar<T>(), |
1401 | l2.scalar<T>(), delta.flat<T>()); |
1402 | |
1403 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1404 | } |
1405 | |
1406 | private: |
1407 | bool use_exclusive_lock_; |
1408 | }; |
1409 | |
1410 | #define REGISTER_KERNELS(D, T) \ |
1411 | REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent") \ |
1412 | .Device(DEVICE_##D) \ |
1413 | .TypeConstraint<T>("T"), \ |
1414 | ApplyProximalGradientDescentOp<D##Device, T>); \ |
1415 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalGradientDescent") \ |
1416 | .HostMemory("var") \ |
1417 | .Device(DEVICE_##D) \ |
1418 | .TypeConstraint<T>("T"), \ |
1419 | ApplyProximalGradientDescentOp<D##Device, T>); |
1420 | |
1421 | REGISTER_KERNELS(CPU, float); |
1422 | REGISTER_KERNELS(CPU, double); |
1423 | #undef REGISTER_KERNELS |
1424 | |
// Note: this op works on CPU only.
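// Applies the same proximal (FOBOS) update as ApplyProximalGradientDescentOp
// above, but only to the rows of var selected by `indices`, one gradient row
// per index.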
1426 | template <typename T, typename Tindex> |
1427 | class SparseApplyProximalGradientDescentOp : public OpKernel { |
1428 | public: |
1429 | explicit SparseApplyProximalGradientDescentOp(OpKernelConstruction* ctx) |
1430 | : OpKernel(ctx) { |
1431 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
1432 | } |
1433 | |
1434 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
1435 | const bool sparse = true; |
1436 | auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>( |
1437 | ctx, use_exclusive_lock_, sparse, {0}); |
1438 | Tensor var; |
1439 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
1440 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1441 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
1442 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
1443 | |
1444 | const Tensor& lr = ctx->input(1); |
1445 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
1446 | errors::InvalidArgument("lr is not a scalar: " , |
1447 | lr.shape().DebugString())); |
1448 | const Tensor& l1 = ctx->input(2); |
1449 | OP_REQUIRES( |
1450 | ctx, TensorShapeUtils::IsScalar(l1.shape()), |
1451 | errors::InvalidArgument("l1 regularization strength is not a scalar: " , |
1452 | l1.shape().DebugString())); |
1453 | const Tensor& l2 = ctx->input(3); |
1454 | OP_REQUIRES( |
1455 | ctx, TensorShapeUtils::IsScalar(l2.shape()), |
1456 | errors::InvalidArgument("l2 regularization strength is not a scalar: " , |
1457 | l2.shape().DebugString())); |
1458 | |
1459 | const Tensor& grad = ctx->input(4); |
1460 | const Tensor& indices = ctx->input(5); |
1461 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
1462 | errors::InvalidArgument("indices must be one-dimensional" )); |
1463 | |
1464 | int64_t inner_dim = 1; |
1465 | for (int d = 1; d < var.dims(); d++) { |
1466 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
1467 | errors::InvalidArgument(strings::StrCat( |
1468 | "var and grad must match in dimension " , d))); |
1469 | inner_dim *= grad.dim_size(d); |
1470 | } |
1471 | const Tindex N = indices.dim_size(0); |
1472 | OP_REQUIRES( |
1473 | ctx, grad.dim_size(0) == N, |
1474 | errors::InvalidArgument( |
1475 | "grad must be the same size as indices in the first dimension." )); |
1476 | OP_REQUIRES(ctx, inner_dim > 0, |
1477 | errors::InvalidArgument( |
1478 | "Inner dimension should be greater than zero." )); |
1479 | |
1480 | if (N > 0) { |
1481 | if (inner_dim > 1) { |
1482 | const Tindex first_dim_size = var.dim_size(0); |
1483 | auto indices_vec = indices.vec<Tindex>(); |
1484 | auto var_flat = var.flat_outer_dims<T>(); |
1485 | auto grad_flat = grad.flat_outer_dims<T>(); |
1486 | T lr_scalar = lr.scalar<T>()(); |
1487 | T l1_scalar = l1.scalar<T>()(); |
1488 | T l2_scalar = l2.scalar<T>()(); |
1489 | |
1490 | // TODO(xbing): extract the common logic for the Fobos update. |
1491 | for (Tindex i = 0; i < N; i++) { |
1492 | const Tindex index = internal::SubtleMustCopy(indices_vec(i)); |
1493 | OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), |
1494 | errors::InvalidArgument( |
1495 | strings::StrCat("Index " , index, " at offset " , i, |
1496 | " in indices is out of range" ))); |
1497 | auto g = grad_flat.template chip<0>(i); |
1498 | auto v = var_flat.template chip<0>(index); |
1499 | // compute learning_rate for current step. |
1500 | auto learning_rate = v.constant(lr_scalar); |
1501 | auto prox_v = v; |
1502 | // v = w - g * learning_rate. |
1503 | prox_v -= g * learning_rate; |
1504 | if (l1_scalar > 0) { |
            // Soft-threshold:
            //   v = sign(prox_v) * max(|prox_v| - lr * l1, 0) / (1 + lr * l2)
1506 | v = prox_v.sign() * |
1507 | (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar)) |
1508 | .cwiseMax(static_cast<T>(0.0)) / |
1509 | (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); |
1510 | } else { |
1511 | v = prox_v / |
1512 | (v.constant(1.0) + v.constant(l2_scalar) * learning_rate); |
1513 | } |
1514 | } |
1515 | } else { |
1516 | auto indices_vec = indices.vec<Tindex>(); |
1517 | auto var_flat = var.flat<T>(); |
1518 | auto grad_flat = grad.flat<T>(); |
1519 | T lr_scalar = lr.scalar<T>()(); |
1520 | T l1_scalar = l1.scalar<T>()(); |
1521 | T l2_scalar = l2.scalar<T>()(); |
1522 | const Tindex first_dim_size = var_flat.size(); |
1523 | |
1524 | for (Tindex i = 0; i < N; i++) { |
1525 | const Tindex index = internal::SubtleMustCopy(indices_vec(i)); |
1526 | OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), |
1527 | errors::InvalidArgument( |
1528 | strings::StrCat("Index " , index, " at offset " , i, |
1529 | " in indices is out of range" ))); |
1530 | const T& g = grad_flat(i); |
1531 | auto learning_rate = lr_scalar; |
1532 | auto prox_v = var_flat(index); |
1533 | prox_v -= learning_rate * g; |
1534 | if (l1_scalar > 0) { |
1535 | var_flat(index) = |
1536 | sgn(prox_v) * |
1537 | std::max(std::abs(prox_v) - learning_rate * l1_scalar, |
1538 | static_cast<T>(0.0)) / |
1539 | (1.0 + l2_scalar * learning_rate); |
1540 | } else { |
1541 | var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate); |
1542 | } |
1543 | } |
1544 | } |
1545 | } |
1546 | |
1547 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1548 | } |
1549 | |
1550 | private: |
1551 | bool use_exclusive_lock_; |
1552 | }; |
1553 | |
1554 | #define REGISTER_KERNELS(T, Tindices) \ |
1555 | REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent") \ |
1556 | .Device(DEVICE_CPU) \ |
1557 | .TypeConstraint<T>("T") \ |
1558 | .TypeConstraint<Tindices>("Tindices"), \ |
1559 | SparseApplyProximalGradientDescentOp<T, Tindices>); \ |
1560 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalGradientDescent") \ |
1561 | .Device(DEVICE_CPU) \ |
1562 | .TypeConstraint<T>("T") \ |
1563 | .TypeConstraint<Tindices>("Tindices"), \ |
1564 | SparseApplyProximalGradientDescentOp<T, Tindices>); |
1565 | |
1566 | REGISTER_KERNELS(float, int32); |
1567 | REGISTER_KERNELS(float, int64_t); |
1568 | REGISTER_KERNELS(double, int32); |
1569 | REGISTER_KERNELS(double, int64_t); |
1570 | #undef REGISTER_KERNELS |
1571 | |
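// Adagrad update, roughly (see functor::ApplyAdagrad):
//   accum += grad^2                  // skipped when update_slots is false
//   var   -= lr * grad / sqrt(accum)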
1572 | template <typename Device, typename T> |
1573 | class ApplyAdagradOp : public OpKernel { |
1574 | public: |
1575 | explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
1576 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
1577 | OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots" , &update_slots_)); |
1578 | } |
1579 | |
1580 | void Compute(OpKernelContext* ctx) override { |
1581 | const bool sparse = false; |
1582 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
1583 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
1584 | Tensor var; |
1585 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1586 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1587 | Tensor accum; |
1588 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1589 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
1590 | OP_REQUIRES( |
1591 | ctx, var.IsInitialized(), |
1592 | errors::FailedPrecondition( |
1593 | "Attempting to use uninitialized variables: " , requested_input(0))); |
1594 | OP_REQUIRES( |
1595 | ctx, accum.IsInitialized(), |
1596 | errors::FailedPrecondition( |
1597 | "Attempting to use uninitialized variables: " , requested_input(1))); |
1598 | const Tensor& lr = ctx->input(2); |
1599 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
1600 | errors::InvalidArgument("lr is not a scalar: " , |
1601 | lr.shape().DebugString())); |
1602 | const Tensor& grad = ctx->input(3); |
1603 | OP_REQUIRES( |
1604 | ctx, var.shape().IsSameSize(accum.shape()), |
1605 | errors::InvalidArgument("var and accum do not have the same shape" , |
1606 | var.shape().DebugString(), " " , |
1607 | accum.shape().DebugString())); |
1608 | OP_REQUIRES( |
1609 | ctx, var.shape().IsSameSize(grad.shape()), |
1610 | errors::InvalidArgument("var and grad do not have the same shape" , |
1611 | var.shape().DebugString(), " " , |
1612 | grad.shape().DebugString())); |
1613 | |
1614 | const Device& device = ctx->template eigen_device<Device>(); |
1615 | functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(), |
1616 | lr.scalar<T>(), grad.flat<T>(), |
1617 | update_slots_); |
1618 | |
1619 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1620 | } |
1621 | |
1622 | private: |
1623 | bool use_exclusive_lock_; |
1624 | bool update_slots_; |
1625 | }; |
1626 | |
1627 | #define REGISTER_KERNELS(D, T) \ |
1628 | REGISTER_KERNEL_BUILDER( \ |
1629 | Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
1630 | ApplyAdagradOp<D##Device, T>); \ |
1631 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagrad") \ |
1632 | .HostMemory("var") \ |
1633 | .HostMemory("accum") \ |
1634 | .Device(DEVICE_##D) \ |
1635 | .TypeConstraint<T>("T"), \ |
1636 | ApplyAdagradOp<D##Device, T>); |
1637 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
1638 | |
1639 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
1640 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
1641 | |
1642 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1643 | // Forward declarations of the functor specializations for GPU. |
1644 | namespace functor { |
1645 | #define DECLARE_GPU_SPEC(T) \ |
1646 | template <> \ |
1647 | void ApplyAdagrad<GPUDevice, T>::operator()( \ |
1648 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
1649 | typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \ |
1650 | typename TTypes<T>::ConstFlat grad, bool update_slots); \ |
1651 | extern template struct ApplyAdagrad<GPUDevice, T>; |
1652 | DECLARE_GPU_SPEC(Eigen::half); |
1653 | DECLARE_GPU_SPEC(float); |
1654 | DECLARE_GPU_SPEC(double); |
1655 | DECLARE_GPU_SPEC(complex64); |
1656 | DECLARE_GPU_SPEC(complex128); |
1657 | #undef DECLARE_GPU_SPEC |
1658 | } // namespace functor |
1659 | |
1660 | REGISTER_KERNELS(GPU, Eigen::half); |
1661 | REGISTER_KERNELS(GPU, float); |
1662 | REGISTER_KERNELS(GPU, double); |
1663 | REGISTER_KERNELS(GPU, complex64); |
1664 | REGISTER_KERNELS(GPU, complex128); |
1665 | #endif |
1666 | #undef REGISTER_CPU_KERNELS |
1667 | #undef REGISTER_KERNELS |
1668 | |
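// AdagradV2 is Adagrad with an explicit epsilon in the denominator, roughly:
//   accum += grad^2                  // skipped when update_slots is false
//   var   -= lr * grad / (sqrt(accum) + epsilon)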
1669 | template <typename Device, typename T> |
1670 | class ApplyAdagradV2Op : public OpKernel { |
1671 | public: |
1672 | explicit ApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { |
1673 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
1674 | OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots" , &update_slots_)); |
1675 | } |
1676 | |
1677 | void Compute(OpKernelContext* ctx) override { |
1678 | const bool sparse = false; |
1679 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
1680 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
1681 | Tensor var; |
1682 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1683 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1684 | Tensor accum; |
1685 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1686 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
1687 | OP_REQUIRES( |
1688 | ctx, var.IsInitialized(), |
1689 | errors::FailedPrecondition( |
1690 | "Attempting to use uninitialized variables: " , requested_input(0))); |
1691 | OP_REQUIRES( |
1692 | ctx, accum.IsInitialized(), |
1693 | errors::FailedPrecondition( |
1694 | "Attempting to use uninitialized variables: " , requested_input(1))); |
1695 | const Tensor& lr = ctx->input(2); |
1696 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
1697 | errors::InvalidArgument("lr is not a scalar: " , |
1698 | lr.shape().DebugString())); |
1699 | const Tensor& epsilon = ctx->input(3); |
1700 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
1701 | errors::InvalidArgument("epsilon is not a scalar: " , |
1702 | epsilon.shape().DebugString())); |
1703 | const Tensor& grad = ctx->input(4); |
1704 | OP_REQUIRES( |
1705 | ctx, var.shape().IsSameSize(accum.shape()), |
1706 | errors::InvalidArgument("var and accum do not have the same shape" , |
1707 | var.shape().DebugString(), " " , |
1708 | accum.shape().DebugString())); |
1709 | OP_REQUIRES( |
1710 | ctx, var.shape().IsSameSize(grad.shape()), |
1711 | errors::InvalidArgument("var and grad do not have the same shape" , |
1712 | var.shape().DebugString(), " " , |
1713 | grad.shape().DebugString())); |
1714 | |
1715 | const Device& device = ctx->template eigen_device<Device>(); |
1716 | functor::ApplyAdagradV2<Device, T>()(device, var.flat<T>(), accum.flat<T>(), |
1717 | lr.scalar<T>(), epsilon.scalar<T>(), |
1718 | grad.flat<T>(), update_slots_); |
1719 | |
1720 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1721 | } |
1722 | |
1723 | private: |
1724 | bool use_exclusive_lock_; |
1725 | bool update_slots_; |
1726 | }; |
1727 | |
1728 | #define REGISTER_KERNELS(D, T) \ |
1729 | REGISTER_KERNEL_BUILDER( \ |
1730 | Name("ApplyAdagradV2").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
1731 | ApplyAdagradV2Op<D##Device, T>); \ |
1732 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradV2") \ |
1733 | .HostMemory("var") \ |
1734 | .HostMemory("accum") \ |
1735 | .Device(DEVICE_##D) \ |
1736 | .TypeConstraint<T>("T"), \ |
1737 | ApplyAdagradV2Op<D##Device, T>); |
1738 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
1739 | |
1740 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
1741 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
1742 | |
1743 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1744 | // Forward declarations of the functor specializations for GPU. |
1745 | namespace functor { |
1746 | #define DECLARE_GPU_SPEC(T) \ |
1747 | template <> \ |
1748 | void ApplyAdagradV2<GPUDevice, T>::operator()( \ |
1749 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
1750 | typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \ |
1751 | typename TTypes<T>::ConstScalar epsilon, \ |
1752 | typename TTypes<T>::ConstFlat grad, bool update_slots); \ |
1753 | extern template struct ApplyAdagradV2<GPUDevice, T>; |
1754 | DECLARE_GPU_SPEC(Eigen::half); |
1755 | DECLARE_GPU_SPEC(float); |
1756 | DECLARE_GPU_SPEC(double); |
1757 | DECLARE_GPU_SPEC(complex64); |
1758 | DECLARE_GPU_SPEC(complex128); |
1759 | #undef DECLARE_GPU_SPEC |
1760 | } // namespace functor |
1761 | |
1762 | REGISTER_KERNELS(GPU, Eigen::half); |
1763 | REGISTER_KERNELS(GPU, float); |
1764 | REGISTER_KERNELS(GPU, double); |
1765 | REGISTER_KERNELS(GPU, complex64); |
1766 | REGISTER_KERNELS(GPU, complex128); |
1767 | #endif |
1768 | #undef REGISTER_CPU_KERNELS |
1769 | #undef REGISTER_KERNELS |
1770 | |
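// Proximal Adagrad combines the Adagrad step size with the FOBOS proximal
// update, roughly (see functor::ApplyProximalAdagrad):
//   accum  += grad^2
//   lr_adj  = lr / sqrt(accum)
//   prox_v  = var - lr_adj * grad
//   var     = sign(prox_v) * max(|prox_v| - lr_adj * l1, 0) / (1 + lr_adj * l2)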
1771 | template <typename Device, typename T> |
1772 | class ApplyProximalAdagradOp : public OpKernel { |
1773 | public: |
1774 | explicit ApplyProximalAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
1775 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
1776 | } |
1777 | |
1778 | void Compute(OpKernelContext* ctx) override { |
1779 | const bool sparse = false; |
1780 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
1781 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
1782 | Tensor var; |
1783 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1784 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1785 | Tensor accum; |
1786 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1787 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
1788 | OP_REQUIRES( |
1789 | ctx, var.IsInitialized(), |
1790 | errors::FailedPrecondition( |
1791 | "Attempting to use uninitialized variables: " , requested_input(0))); |
1792 | OP_REQUIRES( |
1793 | ctx, accum.IsInitialized(), |
1794 | errors::FailedPrecondition( |
1795 | "Attempting to use uninitialized variables: " , requested_input(1))); |
1796 | OP_REQUIRES( |
1797 | ctx, var.shape().IsSameSize(accum.shape()), |
1798 | errors::InvalidArgument("var and accum do not have the same shape" , |
1799 | var.shape().DebugString(), " " , |
1800 | accum.shape().DebugString())); |
1801 | const Tensor& lr = ctx->input(2); |
1802 | OP_REQUIRES(ctx, |
1803 | TensorShapeUtils::IsScalar(lr.shape()) && |
1804 | (!std::is_same<Device, CPUDevice>::value || |
1805 | lr.scalar<T>()() > static_cast<T>(0)), |
1806 | errors::InvalidArgument("lr is not a positive scalar: " , |
1807 | lr.shape().DebugString())); |
1808 | const Tensor& l1 = ctx->input(3); |
1809 | OP_REQUIRES(ctx, |
1810 | TensorShapeUtils::IsScalar(l1.shape()) && |
1811 | (!std::is_same<Device, CPUDevice>::value || |
1812 | l1.scalar<T>()() >= static_cast<T>(0)), |
1813 | errors::InvalidArgument("l1 regularization strength is not a " |
1814 | "non-negative scalar: " , |
1815 | l1.shape().DebugString())); |
1816 | const Tensor& l2 = ctx->input(4); |
1817 | OP_REQUIRES(ctx, |
1818 | TensorShapeUtils::IsScalar(l2.shape()) && |
1819 | (!std::is_same<Device, CPUDevice>::value || |
1820 | l2.scalar<T>()() >= static_cast<T>(0)), |
1821 | errors::InvalidArgument("l2 regularization strength is not a " |
1822 | "non-negative scalar: " , |
1823 | l2.shape().DebugString())); |
1824 | const Tensor& grad = ctx->input(5); |
1825 | OP_REQUIRES( |
1826 | ctx, var.shape().IsSameSize(grad.shape()), |
1827 | errors::InvalidArgument("var and grad do not have the same shape" , |
1828 | var.shape().DebugString(), " " , |
1829 | grad.shape().DebugString())); |
1830 | |
1831 | const Device& device = ctx->template eigen_device<Device>(); |
1832 | functor::ApplyProximalAdagrad<Device, T>()( |
1833 | device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), |
1834 | l2.scalar<T>(), grad.flat<T>()); |
1835 | |
1836 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1837 | } |
1838 | |
1839 | private: |
1840 | bool use_exclusive_lock_; |
1841 | }; |
1842 | |
1843 | #define REGISTER_KERNELS(D, T) \ |
1844 | REGISTER_KERNEL_BUILDER( \ |
1845 | Name("ApplyProximalAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
1846 | ApplyProximalAdagradOp<D##Device, T>); \ |
1847 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalAdagrad") \ |
1848 | .Device(DEVICE_##D) \ |
1849 | .HostMemory("var") \ |
1850 | .HostMemory("accum") \ |
1851 | .TypeConstraint<T>("T"), \ |
1852 | ApplyProximalAdagradOp<D##Device, T>); |
1853 | |
1854 | REGISTER_KERNELS(CPU, float); |
1855 | REGISTER_KERNELS(CPU, double); |
1856 | |
1857 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1858 | // Forward declarations of the functor specializations for GPU. |
1859 | namespace functor { |
1860 | #define DECLARE_GPU_SPEC(T) \ |
1861 | template <> \ |
1862 | void ApplyProximalAdagrad<GPUDevice, T>::operator()( \ |
1863 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
1864 | typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \ |
1865 | typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \ |
1866 | typename TTypes<T>::ConstFlat grad); \ |
1867 | extern template struct ApplyProximalAdagrad<GPUDevice, T>; |
1868 | DECLARE_GPU_SPEC(Eigen::half); |
1869 | DECLARE_GPU_SPEC(float); |
1870 | DECLARE_GPU_SPEC(double); |
1871 | #undef DECLARE_GPU_SPEC |
1872 | } // namespace functor |
1873 | |
1874 | REGISTER_KERNELS(GPU, Eigen::half); |
1875 | REGISTER_KERNELS(GPU, float); |
1876 | REGISTER_KERNELS(GPU, double); |
1877 | #endif |
1878 | #undef REGISTER_CPU_KERNELS |
1879 | #undef REGISTER_KERNELS |
1880 | |
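// Row-sparse Adagrad: for each i, with row r = indices(i), roughly
//   accum(r) += grad(i)^2            // skipped when update_slots is false
//   var(r)   -= lr * grad(i) / sqrt(accum(r))
// The per-row work is delegated to functor::SparseApplyAdagrad.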
1881 | template <typename Device, typename T, typename Tindex> |
1882 | class SparseApplyAdagradOp : public OpKernel { |
1883 | public: |
1884 | explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
1885 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
1886 | OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots" , &update_slots_)); |
1887 | } |
1888 | |
1889 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
1890 | const bool sparse = true; |
1891 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
1892 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
1893 | Tensor var; |
1894 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1895 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
1896 | Tensor accum; |
1897 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
1898 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
1899 | OP_REQUIRES( |
1900 | ctx, var.IsInitialized(), |
1901 | errors::FailedPrecondition( |
1902 | "Attempting to use uninitialized variables: " , requested_input(0))); |
1903 | OP_REQUIRES( |
1904 | ctx, accum.IsInitialized(), |
1905 | errors::FailedPrecondition( |
1906 | "Attempting to use uninitialized variables: " , requested_input(1))); |
1907 | OP_REQUIRES( |
1908 | ctx, var.shape().IsSameSize(accum.shape()), |
1909 | errors::InvalidArgument("var and accum do not have the same shape" , |
1910 | var.shape().DebugString(), " " , |
1911 | accum.shape().DebugString())); |
1912 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
1913 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
1914 | |
1915 | const Tensor& lr = ctx->input(2); |
1916 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
1917 | errors::InvalidArgument("lr is not a scalar: " , |
1918 | lr.shape().DebugString())); |
1919 | const Tensor& grad = ctx->input(3); |
1920 | const Tensor& indices = ctx->input(4); |
1921 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
1922 | errors::InvalidArgument("indices must be one-dimensional" )); |
1923 | |
1924 | int64_t inner_dim = 1; |
1925 | for (int d = 1; d < var.dims(); d++) { |
1926 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
1927 | errors::InvalidArgument(strings::StrCat( |
1928 | "var and grad must match in dimension " , d))); |
1929 | inner_dim *= grad.dim_size(d); |
1930 | } |
1931 | const Tindex N = indices.dim_size(0); |
1932 | OP_REQUIRES( |
1933 | ctx, grad.dim_size(0) == N, |
1934 | errors::InvalidArgument( |
1935 | "grad must be the same size as indices in the first dimension." )); |
1936 | |
1937 | OP_REQUIRES(ctx, inner_dim > 0, |
1938 | errors::InvalidArgument( |
1939 | "Inner dimension should be greater than zero." )); |
1940 | |
1941 | const Device& device = ctx->template eigen_device<Device>(); |
1942 | OP_REQUIRES_OK( |
1943 | ctx, functor::SparseApplyAdagrad<Device, T, Tindex, |
1944 | /*has_epsilon = */ false>()( |
1945 | device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(), |
1946 | // Note: Passing lr as a placeholder for unused epsilon. |
1947 | lr.scalar<T>(), lr.scalar<T>(), grad.flat_outer_dims<T>(), |
1948 | indices.vec<Tindex>(), inner_dim, update_slots_)); |
1949 | |
1950 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
1951 | } |
1952 | |
1953 | private: |
1954 | bool use_exclusive_lock_; |
1955 | bool update_slots_; |
1956 | }; |
1957 | |
1958 | #define REGISTER_KERNELS(D, T, Tindices) \ |
1959 | REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad") \ |
1960 | .Device(DEVICE_##D) \ |
1961 | .TypeConstraint<T>("T") \ |
1962 | .TypeConstraint<Tindices>("Tindices"), \ |
1963 | SparseApplyAdagradOp<D##Device, T, Tindices>); \ |
1964 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad") \ |
1965 | .Device(DEVICE_##D) \ |
1966 | .TypeConstraint<T>("T") \ |
1967 | .TypeConstraint<Tindices>("Tindices"), \ |
1968 | SparseApplyAdagradOp<D##Device, T, Tindices>); |
1969 | #define REGISTER_CPU_KERNELS(T) \ |
1970 | REGISTER_KERNELS(CPU, T, int32); \ |
1971 | REGISTER_KERNELS(CPU, T, int64_t); |
1972 | |
1973 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
1974 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
1975 | |
1976 | #undef REGISTER_CPU_KERNELS |
1977 | |
1978 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1979 | // Forward declarations of the functor specializations for GPU. |
1980 | namespace functor { |
1981 | #define DECLARE_GPU_SPEC(T, Tindex) \ |
1982 | template <> \ |
1983 | Status \ |
1984 | SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/false>::operator()( \ |
1985 | const GPUDevice& d, typename TTypes<T>::Matrix var, \ |
1986 | typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \ |
1987 | typename TTypes<T>::ConstScalar epsilon, \ |
1988 | typename TTypes<T>::ConstMatrix grad, \ |
1989 | typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \ |
1990 | bool update_slots); \ |
1991 | extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \ |
1992 | /*has_epsilon=*/false>; |
1993 | DECLARE_GPU_SPEC(Eigen::half, int32); |
1994 | DECLARE_GPU_SPEC(Eigen::half, int64_t); |
1995 | DECLARE_GPU_SPEC(float, int32); |
1996 | DECLARE_GPU_SPEC(float, int64_t); |
1997 | DECLARE_GPU_SPEC(double, int32); |
1998 | DECLARE_GPU_SPEC(double, int64_t); |
1999 | #undef DECLARE_GPU_SPEC |
2000 | } // namespace functor |
2001 | |
2002 | REGISTER_KERNELS(GPU, Eigen::half, int32); |
2003 | REGISTER_KERNELS(GPU, Eigen::half, int64_t); |
2004 | REGISTER_KERNELS(GPU, float, int32); |
2005 | REGISTER_KERNELS(GPU, float, int64_t); |
2006 | REGISTER_KERNELS(GPU, double, int32); |
2007 | REGISTER_KERNELS(GPU, double, int64_t); |
2008 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2009 | #undef REGISTER_KERNELS |
2010 | |
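// Row-sparse AdagradV2: identical to SparseApplyAdagradOp above except that
// the denominator is sqrt(accum(r)) + epsilon.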
2011 | template <typename Device, typename T, typename Tindex> |
2012 | class SparseApplyAdagradV2Op : public OpKernel { |
2013 | public: |
2014 | explicit SparseApplyAdagradV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { |
2015 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
2016 | OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots" , &update_slots_)); |
2017 | } |
2018 | |
2019 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
2020 | const bool sparse = true; |
2021 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
2022 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
2023 | Tensor var; |
2024 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2025 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
2026 | Tensor accum; |
2027 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2028 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
2029 | OP_REQUIRES( |
2030 | ctx, var.IsInitialized(), |
2031 | errors::FailedPrecondition( |
2032 | "Attempting to use uninitialized variables: " , requested_input(0))); |
2033 | OP_REQUIRES( |
2034 | ctx, accum.IsInitialized(), |
2035 | errors::FailedPrecondition( |
2036 | "Attempting to use uninitialized variables: " , requested_input(1))); |
2037 | OP_REQUIRES( |
2038 | ctx, var.shape().IsSameSize(accum.shape()), |
2039 | errors::InvalidArgument("var and accum do not have the same shape" , |
2040 | var.shape().DebugString(), " " , |
2041 | accum.shape().DebugString())); |
2042 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
2043 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
2044 | |
2045 | const Tensor& lr = ctx->input(2); |
2046 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
2047 | errors::InvalidArgument("lr is not a scalar: " , |
2048 | lr.shape().DebugString())); |
2049 | const Tensor& epsilon = ctx->input(3); |
2050 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
2051 | errors::InvalidArgument("epsilon is not a scalar: " , |
2052 | epsilon.shape().DebugString())); |
2053 | const Tensor& grad = ctx->input(4); |
2054 | const Tensor& indices = ctx->input(5); |
2055 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
2056 | errors::InvalidArgument("indices must be one-dimensional" )); |
2057 | |
2058 | int64_t inner_dim = 1; |
2059 | for (int d = 1; d < var.dims(); d++) { |
2060 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
2061 | errors::InvalidArgument(strings::StrCat( |
2062 | "var and grad must match in dimension " , d))); |
2063 | inner_dim *= grad.dim_size(d); |
2064 | } |
2065 | const Tindex N = indices.dim_size(0); |
2066 | OP_REQUIRES( |
2067 | ctx, grad.dim_size(0) == N, |
2068 | errors::InvalidArgument( |
2069 | "grad must be the same size as indices in the first dimension." )); |
2070 | |
2071 | OP_REQUIRES(ctx, inner_dim > 0, |
2072 | errors::InvalidArgument( |
2073 | "Inner dimension should be greater than zero." )); |
2074 | |
2075 | const Device& device = ctx->template eigen_device<Device>(); |
2076 | OP_REQUIRES_OK( |
2077 | ctx, functor::SparseApplyAdagrad<Device, T, Tindex, |
2078 | /*has_epsilon = */ true>()( |
2079 | device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(), |
2080 | lr.scalar<T>(), epsilon.scalar<T>(), grad.flat_outer_dims<T>(), |
2081 | indices.vec<Tindex>(), inner_dim, update_slots_)); |
2082 | |
2083 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
2084 | } |
2085 | |
2086 | private: |
2087 | bool use_exclusive_lock_; |
2088 | bool update_slots_; |
2089 | }; |
2090 | |
2091 | #define REGISTER_KERNELS(D, T, Tindices) \ |
2092 | REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradV2") \ |
2093 | .Device(DEVICE_##D) \ |
2094 | .TypeConstraint<T>("T") \ |
2095 | .TypeConstraint<Tindices>("Tindices"), \ |
2096 | SparseApplyAdagradV2Op<D##Device, T, Tindices>); \ |
2097 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradV2") \ |
2098 | .Device(DEVICE_##D) \ |
2099 | .TypeConstraint<T>("T") \ |
2100 | .TypeConstraint<Tindices>("Tindices"), \ |
2101 | SparseApplyAdagradV2Op<D##Device, T, Tindices>); |
2102 | #define REGISTER_CPU_KERNELS(T) \ |
2103 | REGISTER_KERNELS(CPU, T, int32); \ |
2104 | REGISTER_KERNELS(CPU, T, int64_t); |
2105 | |
2106 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
2107 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
2108 | |
2109 | #undef REGISTER_CPU_KERNELS |
2110 | |
2111 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2112 | // Forward declarations of the functor specializations for GPU. |
2113 | namespace functor { |
2114 | #define DECLARE_GPU_SPEC(T, Tindex) \ |
2115 | template <> \ |
2116 | Status \ |
2117 | SparseApplyAdagrad<GPUDevice, T, Tindex, /*has_epsilon=*/true>::operator()( \ |
2118 | const GPUDevice& d, typename TTypes<T>::Matrix var, \ |
2119 | typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \ |
2120 | typename TTypes<T>::ConstScalar epsilon, \ |
2121 | typename TTypes<T>::ConstMatrix grad, \ |
2122 | typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \ |
2123 | bool update_slots); \ |
2124 | extern template struct SparseApplyAdagrad<GPUDevice, T, Tindex, \ |
2125 | /*has_epsilon=*/true>; |
2126 | DECLARE_GPU_SPEC(Eigen::half, int32); |
2127 | DECLARE_GPU_SPEC(Eigen::half, int64_t); |
2128 | DECLARE_GPU_SPEC(float, int32); |
2129 | DECLARE_GPU_SPEC(float, int64_t); |
2130 | DECLARE_GPU_SPEC(double, int32); |
2131 | DECLARE_GPU_SPEC(double, int64_t); |
2132 | #undef DECLARE_GPU_SPEC |
2133 | } // namespace functor |
2134 | |
2135 | REGISTER_KERNELS(GPU, Eigen::half, int32); |
2136 | REGISTER_KERNELS(GPU, Eigen::half, int64_t); |
2137 | REGISTER_KERNELS(GPU, float, int32); |
2138 | REGISTER_KERNELS(GPU, float, int64_t); |
2139 | REGISTER_KERNELS(GPU, double, int32); |
2140 | REGISTER_KERNELS(GPU, double, int64_t); |
2141 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2142 | #undef REGISTER_KERNELS |
2143 | |
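// Row-sparse version of the proximal Adagrad update described above
// ApplyProximalAdagradOp; only the rows selected by `indices` are updated,
// via functor::SparseApplyProximalAdagrad.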
2144 | template <typename Device, typename T, typename Tindex> |
2145 | class SparseApplyProximalAdagradOp : public OpKernel { |
2146 | public: |
2147 | explicit SparseApplyProximalAdagradOp(OpKernelConstruction* ctx) |
2148 | : OpKernel(ctx) { |
2149 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
2150 | } |
2151 | |
2152 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
2153 | const bool sparse = true; |
2154 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
2155 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
2156 | Tensor var; |
2157 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2158 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
2159 | Tensor accum; |
2160 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2161 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
2162 | OP_REQUIRES( |
2163 | ctx, var.IsInitialized(), |
2164 | errors::FailedPrecondition( |
2165 | "Attempting to use uninitialized variables: " , requested_input(0))); |
2166 | OP_REQUIRES( |
2167 | ctx, accum.IsInitialized(), |
2168 | errors::FailedPrecondition( |
2169 | "Attempting to use uninitialized variables: " , requested_input(1))); |
2170 | OP_REQUIRES( |
2171 | ctx, var.shape().IsSameSize(accum.shape()), |
2172 | errors::InvalidArgument("var and accum do not have the same shape" , |
2173 | var.shape().DebugString(), " " , |
2174 | accum.shape().DebugString())); |
2175 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
2176 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
2177 | |
2178 | const Tensor& lr = ctx->input(2); |
2179 | OP_REQUIRES(ctx, |
2180 | TensorShapeUtils::IsScalar(lr.shape()) && |
2181 | (!std::is_same<Device, CPUDevice>::value || |
2182 | lr.scalar<T>()() > static_cast<T>(0)), |
2183 | errors::InvalidArgument("lr is not a positive scalar: " , |
2184 | lr.shape().DebugString())); |
2185 | const Tensor& l1 = ctx->input(3); |
2186 | OP_REQUIRES(ctx, |
2187 | TensorShapeUtils::IsScalar(l1.shape()) && |
2188 | (!std::is_same<Device, CPUDevice>::value || |
2189 | l1.scalar<T>()() >= static_cast<T>(0)), |
2190 | errors::InvalidArgument("l1 regularization strength is not a " |
2191 | "non-negative scalar: " , |
2192 | l1.shape().DebugString())); |
2193 | const Tensor& l2 = ctx->input(4); |
2194 | OP_REQUIRES(ctx, |
2195 | TensorShapeUtils::IsScalar(l2.shape()) && |
2196 | (!std::is_same<Device, CPUDevice>::value || |
2197 | l2.scalar<T>()() >= static_cast<T>(0)), |
2198 | errors::InvalidArgument("l2 regularization strength is not a " |
2199 | "non-negative scalar: " , |
2200 | l2.shape().DebugString())); |
2201 | |
2202 | const Tensor& grad = ctx->input(5); |
2203 | const Tensor& indices = ctx->input(6); |
2204 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
2205 | errors::InvalidArgument("indices must be one-dimensional" )); |
2206 | |
2207 | int64_t inner_dim = 1; |
2208 | for (int d = 1; d < var.dims(); d++) { |
2209 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
2210 | errors::InvalidArgument(strings::StrCat( |
2211 | "var and grad must match in dimension " , d))); |
2212 | inner_dim *= grad.dim_size(d); |
2213 | } |
2214 | const Tindex N = indices.dim_size(0); |
2215 | OP_REQUIRES( |
2216 | ctx, grad.dim_size(0) == N, |
2217 | errors::InvalidArgument( |
2218 | "grad must be the same size as indices in the first dimension." )); |
2219 | |
2220 | OP_REQUIRES(ctx, inner_dim > 0, |
2221 | errors::InvalidArgument( |
2222 | "Inner dimension should be greater than zero." )); |
2223 | |
2224 | const Device& device = ctx->template eigen_device<Device>(); |
2225 | OP_REQUIRES_OK( |
2226 | ctx, functor::SparseApplyProximalAdagrad<Device, T, Tindex>()( |
2227 | device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(), |
2228 | lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), |
2229 | grad.flat_outer_dims<T>(), indices.vec<Tindex>(), inner_dim)); |
2230 | |
2231 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
2232 | } |
2233 | |
2234 | private: |
2235 | bool use_exclusive_lock_; |
2236 | }; |
2237 | |
2238 | #define REGISTER_KERNELS(D, T, Tindices) \ |
2239 | REGISTER_KERNEL_BUILDER( \ |
2240 | Name("SparseApplyProximalAdagrad") \ |
2241 | .Device(DEVICE_##D) \ |
2242 | .TypeConstraint<T>("T") \ |
2243 | .TypeConstraint<Tindices>("Tindices"), \ |
2244 | SparseApplyProximalAdagradOp<D##Device, T, Tindices>); \ |
2245 | REGISTER_KERNEL_BUILDER( \ |
2246 | Name("ResourceSparseApplyProximalAdagrad") \ |
2247 | .Device(DEVICE_##D) \ |
2248 | .TypeConstraint<T>("T") \ |
2249 | .TypeConstraint<Tindices>("Tindices"), \ |
2250 | SparseApplyProximalAdagradOp<D##Device, T, Tindices>); |
2251 | |
2252 | REGISTER_KERNELS(CPU, float, int32); |
2253 | REGISTER_KERNELS(CPU, float, int64_t); |
2254 | REGISTER_KERNELS(CPU, double, int32); |
2255 | REGISTER_KERNELS(CPU, double, int64_t); |
2256 | |
2257 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2258 | // Forward declarations of the functor specializations for GPU. |
2259 | namespace functor { |
2260 | #define DECLARE_GPU_SPEC(T, Tindex) \ |
2261 | template <> \ |
2262 | Status SparseApplyProximalAdagrad<GPUDevice, T, Tindex>::operator()( \ |
2263 | const GPUDevice& d, typename TTypes<T>::Matrix var, \ |
2264 | typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \ |
2265 | typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \ |
2266 | typename TTypes<T>::ConstMatrix grad, \ |
2267 | typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim); \ |
2268 | extern template struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex>; |
2269 | DECLARE_GPU_SPEC(Eigen::half, int32); |
2270 | DECLARE_GPU_SPEC(Eigen::half, int64_t); |
2271 | DECLARE_GPU_SPEC(float, int32); |
2272 | DECLARE_GPU_SPEC(float, int64_t); |
2273 | DECLARE_GPU_SPEC(double, int32); |
2274 | DECLARE_GPU_SPEC(double, int64_t); |
2275 | #undef DECLARE_GPU_SPEC |
2276 | } // namespace functor |
2277 | |
2278 | REGISTER_KERNELS(GPU, Eigen::half, int32); |
2279 | REGISTER_KERNELS(GPU, Eigen::half, int64_t); |
2280 | REGISTER_KERNELS(GPU, float, int32); |
2281 | REGISTER_KERNELS(GPU, float, int64_t); |
2282 | REGISTER_KERNELS(GPU, double, int32); |
2283 | REGISTER_KERNELS(GPU, double, int64_t); |
2284 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2285 | #undef REGISTER_KERNELS |
2286 | |
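// Adagrad Dual Averaging (AdagradDA): running sums of the gradients and of
// the squared gradients are kept in the two accumulators, and var is
// recomputed from those sums on every step; the closed-form update is
// spelled out in the sparse kernel below.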
2287 | template <typename Device, typename T> |
2288 | class ApplyAdagradDAOp : public OpKernel { |
2289 | public: |
2290 | explicit ApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
2291 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
2292 | } |
2293 | |
2294 | void Compute(OpKernelContext* ctx) override { |
2295 | const bool sparse = false; |
2296 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
2297 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
2298 | Tensor var; |
2299 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2300 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
2301 | Tensor gradient_accum; |
2302 | OP_REQUIRES_OK( |
2303 | ctx, GetInputTensorFromVariable<Device, T>(ctx, 1, use_exclusive_lock_, |
2304 | sparse, &gradient_accum)); |
2305 | Tensor gradient_squared_accum; |
2306 | OP_REQUIRES_OK( |
2307 | ctx, GetInputTensorFromVariable<Device, T>( |
2308 | ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum)); |
2309 | OP_REQUIRES( |
2310 | ctx, var.IsInitialized(), |
2311 | errors::FailedPrecondition( |
2312 | "Attempting to use uninitialized variables: " , requested_input(0))); |
2313 | OP_REQUIRES( |
2314 | ctx, gradient_accum.IsInitialized(), |
2315 | errors::FailedPrecondition( |
2316 | "Attempting to use uninitialized variables: " , requested_input(1))); |
2317 | OP_REQUIRES( |
2318 | ctx, gradient_squared_accum.IsInitialized(), |
2319 | errors::FailedPrecondition( |
2320 | "Attempting to use uninitialized variables: " , requested_input(2))); |
2321 | OP_REQUIRES( |
2322 | ctx, var.shape().IsSameSize(gradient_accum.shape()), |
2323 | errors::InvalidArgument("var and accum do not have the same shape" , |
2324 | var.shape().DebugString(), " " , |
2325 | gradient_accum.shape().DebugString())); |
2326 | OP_REQUIRES( |
2327 | ctx, var.shape().IsSameSize(gradient_squared_accum.shape()), |
2328 | errors::InvalidArgument("var and accum do not have the same shape" , |
2329 | var.shape().DebugString(), " " , |
2330 | gradient_squared_accum.shape().DebugString())); |
2331 | |
2332 | const Tensor& grad = ctx->input(3); |
2333 | OP_REQUIRES( |
2334 | ctx, var.shape().IsSameSize(grad.shape()), |
2335 | errors::InvalidArgument("var and grad do not have the same shape" , |
2336 | var.shape().DebugString(), " " , |
2337 | grad.shape().DebugString())); |
2338 | |
2339 | const Tensor& lr = ctx->input(4); |
2340 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
2341 | errors::InvalidArgument("lr is not a scalar: " , |
2342 | lr.shape().DebugString())); |
2343 | const Tensor& l1 = ctx->input(5); |
2344 | OP_REQUIRES( |
2345 | ctx, TensorShapeUtils::IsScalar(l1.shape()), |
2346 | errors::InvalidArgument("l1 regularization strength is not a scalar: " , |
2347 | l1.shape().DebugString())); |
2348 | const Tensor& l2 = ctx->input(6); |
2349 | OP_REQUIRES( |
2350 | ctx, TensorShapeUtils::IsScalar(l2.shape()), |
2351 | errors::InvalidArgument("l2 regularization strength is not a scalar: " , |
2352 | l2.shape().DebugString())); |
2353 | const Tensor& global_step = ctx->input(7); |
2354 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()), |
2355 | errors::InvalidArgument("global_step is not a scalar: " , |
2356 | global_step.shape().DebugString())); |
2357 | |
2358 | const Device& device = ctx->template eigen_device<Device>(); |
2359 | functor::ApplyAdagradDA<Device, T>()( |
2360 | device, var.flat<T>(), gradient_accum.flat<T>(), |
2361 | gradient_squared_accum.flat<T>(), lr.scalar<T>(), |
2362 | global_step.scalar<int64_t>()(), l1.scalar<T>(), l2.scalar<T>(), |
2363 | grad.flat<T>()); |
2364 | |
2365 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
2366 | } |
2367 | |
2368 | private: |
2369 | bool use_exclusive_lock_; |
2370 | }; |
2371 | |
2372 | #define REGISTER_KERNELS(D, T) \ |
2373 | REGISTER_KERNEL_BUILDER( \ |
2374 | Name("ApplyAdagradDA").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
2375 | ApplyAdagradDAOp<D##Device, T>); \ |
2376 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradDA") \ |
2377 | .Device(DEVICE_##D) \ |
2378 | .HostMemory("var") \ |
2379 | .HostMemory("gradient_accumulator") \ |
2380 | .HostMemory("gradient_squared_accumulator") \ |
2381 | .TypeConstraint<T>("T"), \ |
2382 | ApplyAdagradDAOp<D##Device, T>); |
2383 | |
2384 | REGISTER_KERNELS(CPU, float); |
2385 | REGISTER_KERNELS(CPU, double); |
2386 | #undef REGISTER_KERNELS |
2387 | |
// Note: this op works on CPU only.
2389 | template <typename T, typename Tindex> |
2390 | class SparseApplyAdagradDAOp : public OpKernel { |
2391 | public: |
2392 | explicit SparseApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
2393 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
2394 | } |
2395 | |
2396 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
2397 | const bool sparse = true; |
2398 | auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>( |
2399 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
2400 | Tensor var; |
2401 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
2402 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
2403 | Tensor gradient_accum; |
2404 | OP_REQUIRES_OK(ctx, |
2405 | GetInputTensorFromVariable<CPUDevice, T>( |
2406 | ctx, 1, use_exclusive_lock_, sparse, &gradient_accum)); |
2407 | Tensor gradient_squared_accum; |
2408 | OP_REQUIRES_OK( |
2409 | ctx, GetInputTensorFromVariable<CPUDevice, T>( |
2410 | ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum)); |
2411 | OP_REQUIRES( |
2412 | ctx, var.IsInitialized(), |
2413 | errors::FailedPrecondition( |
2414 | "Attempting to use uninitialized variables: " , requested_input(0))); |
2415 | OP_REQUIRES( |
2416 | ctx, gradient_accum.IsInitialized(), |
2417 | errors::FailedPrecondition( |
2418 | "Attempting to use uninitialized variables: " , requested_input(1))); |
2419 | OP_REQUIRES( |
2420 | ctx, gradient_squared_accum.IsInitialized(), |
2421 | errors::FailedPrecondition( |
2422 | "Attempting to use uninitialized variables: " , requested_input(2))); |
2423 | OP_REQUIRES( |
2424 | ctx, var.shape().IsSameSize(gradient_accum.shape()), |
2425 | errors::InvalidArgument("var and accum do not have the same shape" , |
2426 | var.shape().DebugString(), " " , |
2427 | gradient_accum.shape().DebugString())); |
2428 | OP_REQUIRES( |
2429 | ctx, var.shape().IsSameSize(gradient_squared_accum.shape()), |
2430 | errors::InvalidArgument("var and accum do not have the same shape" , |
2431 | var.shape().DebugString(), " " , |
2432 | gradient_squared_accum.shape().DebugString())); |
2433 | |
2434 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
2435 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
2436 | |
2437 | const Tensor& grad = ctx->input(3); |
2438 | const Tensor& indices = ctx->input(4); |
2439 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
2440 | errors::InvalidArgument("indices must be one-dimensional" )); |
2441 | |
2442 | const Tensor& lr = ctx->input(5); |
2443 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
2444 | errors::InvalidArgument("lr is not a scalar: " , |
2445 | lr.shape().DebugString())); |
2446 | |
2447 | const Tensor& l1 = ctx->input(6); |
2448 | OP_REQUIRES( |
2449 | ctx, TensorShapeUtils::IsScalar(l1.shape()), |
2450 | errors::InvalidArgument("l1 regularization strength is not a scalar: " , |
2451 | l1.shape().DebugString())); |
2452 | |
2453 | const Tensor& l2 = ctx->input(7); |
2454 | OP_REQUIRES( |
2455 | ctx, TensorShapeUtils::IsScalar(l2.shape()), |
2456 | errors::InvalidArgument("l2 regularization strength is not a scalar: " , |
2457 | l2.shape().DebugString())); |
2458 | |
2459 | const Tensor& global_step = ctx->input(8); |
2460 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()), |
2461 | errors::InvalidArgument("global_step is not a scalar: " , |
2462 | global_step.shape().DebugString())); |
2463 | |
2464 | int64_t inner_dim = 1; |
2465 | for (int d = 1; d < var.dims(); d++) { |
2466 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
2467 | errors::InvalidArgument(strings::StrCat( |
2468 | "var and grad must match in dimension " , d))); |
2469 | inner_dim *= grad.dim_size(d); |
2470 | } |
2471 | const Tindex N = indices.dim_size(0); |
2472 | OP_REQUIRES( |
2473 | ctx, grad.dim_size(0) == N, |
2474 | errors::InvalidArgument( |
2475 | "grad must be the same size as indices in the first dimension." )); |
2476 | |
2477 | OP_REQUIRES(ctx, inner_dim > 0, |
2478 | errors::InvalidArgument( |
2479 | "Inner dimension should be greater than zero." )); |
2480 | |
2481 | // AdagradDA update: |
    // Let g be the gradient accumulator, gg the gradient squared accumulator,
    // T the global step, lr the learning rate, and k the initial gradient
    // squared accumulator value.
    // w = \dfrac{sign(-g) * lr * (|g| - l1 * T)_{+}}{l2 * T * lr + \sqrt{k + gg}}
2486 | if (N > 0) { |
2487 | if (inner_dim > 1) { |
2488 | const Tindex first_dim_size = var.dim_size(0); |
2489 | auto indices_vec = indices.vec<Tindex>(); |
2490 | auto var_flat = var.flat_outer_dims<T>(); |
2491 | auto gradient_accum_flat = gradient_accum.flat_outer_dims<T>(); |
2492 | auto gradient_squared_accum_flat = |
2493 | gradient_squared_accum.flat_outer_dims<T>(); |
2494 | auto grad_flat = grad.flat_outer_dims<T>(); |
2495 | T lr_scalar = lr.scalar<T>()(); |
2496 | T global_step_scalar = global_step.scalar<int64_t>()(); |
2497 | T l1_scalar = l1.scalar<T>()(); |
2498 | T l2_scalar = l2.scalar<T>()(); |
2499 | const double gs_lr = global_step_scalar * lr_scalar; |
2500 | |
2501 | for (Tindex i = 0; i < N; i++) { |
2502 | const Tindex index = internal::SubtleMustCopy(indices_vec(i)); |
2503 | OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), |
2504 | errors::InvalidArgument( |
2505 | strings::StrCat("Index " , index, " at offset " , i, |
2506 | " in indices is out of range" ))); |
2507 | auto ga = gradient_accum_flat.template chip<0>(index); |
2508 | auto da = gradient_squared_accum_flat.template chip<0>(index); |
2509 | auto g = grad_flat.template chip<0>(i); |
2510 | auto v = var_flat.template chip<0>(index); |
2511 | ga += g; |
2512 | da += g.square(); |
2513 | if (l1_scalar > 0) { |
2514 | v = ga.constant(-1.0) * ga.sign() * |
2515 | ((ga.abs() / ga.constant(global_step_scalar)) - |
2516 | ga.constant(l1_scalar)) |
2517 | .cwiseMax(static_cast<T>(0.0)) / |
2518 | (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr)); |
2519 | } else { |
2520 | v = ga.constant(-1.0) * (ga / ga.constant(global_step_scalar)) / |
2521 | (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr)); |
2522 | } |
2523 | } |
2524 | } else { |
2525 | auto indices_vec = indices.vec<Tindex>(); |
2526 | auto var_flat = var.flat<T>(); |
2527 | auto gradient_accum_flat = gradient_accum.flat<T>(); |
2528 | auto gradient_squared_accum_flat = gradient_squared_accum.flat<T>(); |
2529 | auto grad_flat = grad.flat<T>(); |
2530 | const double lr_scalar = lr.scalar<T>()(); |
2531 | const int64_t global_step_scalar = global_step.scalar<int64_t>()(); |
2532 | const double l1_scalar = l1.scalar<T>()(); |
2533 | const double l2_scalar = l2.scalar<T>()(); |
2534 | const Tindex first_dim_size = var_flat.size(); |
2535 | const double gs_l1 = global_step_scalar * l1_scalar; |
2536 | const double gs_l2_lr = global_step_scalar * l2_scalar * lr_scalar; |
2537 | |
2538 | for (Tindex i = 0; i < N; i++) { |
2539 | const Tindex index = internal::SubtleMustCopy(indices_vec(i)); |
2540 | OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), |
2541 | errors::InvalidArgument( |
2542 | strings::StrCat("Index " , index, " at offset " , i, |
2543 | " in indices is out of range" ))); |
2544 | T& ga = gradient_accum_flat(index); |
2545 | T& da = gradient_squared_accum_flat(index); |
2546 | const double g = grad_flat(i); |
2547 | ga += g; |
2548 | da += g * g; |
2549 | if (l1_scalar > 0) { |
2550 | var_flat(index) = sgn(-ga) * lr_scalar * |
2551 | std::max((std::abs(ga) - gs_l1), 0.0) / |
2552 | (gs_l2_lr + std::sqrt(da)); |
2553 | } else { |
2554 | var_flat(index) = (-ga * lr_scalar) / (gs_l2_lr + std::sqrt(da)); |
2555 | } |
2556 | } |
2557 | } |
2558 | } |
2559 | |
2560 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
2561 | } |
2562 | |
2563 | private: |
2564 | bool use_exclusive_lock_; |
2565 | }; |
2566 | |
2567 | #define REGISTER_KERNELS(T, Tindices) \ |
2568 | REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradDA") \ |
2569 | .Device(DEVICE_CPU) \ |
2570 | .TypeConstraint<T>("T") \ |
2571 | .TypeConstraint<Tindices>("Tindices"), \ |
2572 | SparseApplyAdagradDAOp<T, Tindices>); \ |
2573 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradDA") \ |
2574 | .Device(DEVICE_CPU) \ |
2575 | .HostMemory("var") \ |
2576 | .HostMemory("gradient_accumulator") \ |
2577 | .HostMemory("gradient_squared_accumulator") \ |
2578 | .TypeConstraint<T>("T") \ |
2579 | .TypeConstraint<Tindices>("Tindices"), \ |
2580 | SparseApplyAdagradDAOp<T, Tindices>); |
2581 | |
2582 | REGISTER_KERNELS(float, int32); |
2583 | REGISTER_KERNELS(float, int64_t); |
2584 | REGISTER_KERNELS(double, int32); |
2585 | REGISTER_KERNELS(double, int64_t); |
2586 | #undef REGISTER_KERNELS |
2587 | |
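// FTRL-Proximal. Roughly, for the common lr_power == -0.5 case (see
// functor::ApplyFtrl and its V2 / MultiplyLinearByLr variants for the
// general form):
//   accum_new = accum + grad^2
//   linear   += grad - (sqrt(accum_new) - sqrt(accum)) / lr * var
//   var       = (l1 * sign(linear) - linear) /
//               (sqrt(accum_new) / lr + 2 * l2)   if |linear| > l1, else 0
//   accum     = accum_new
// The V2 variant additionally folds the l2_shrinkage term into the gradient
// used for the linear update.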
2588 | template <typename Device, typename T, bool has_l2_shrinkage> |
2589 | class ApplyFtrlOp : public OpKernel { |
2590 | public: |
2591 | explicit ApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
2592 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
2593 | OP_REQUIRES_OK( |
2594 | ctx, ctx->GetAttr("multiply_linear_by_lr" , &multiply_linear_by_lr_)); |
2595 | } |
2596 | |
2597 | void Compute(OpKernelContext* ctx) override { |
2598 | const bool sparse = false; |
2599 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
2600 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
2601 | |
2602 | Tensor var; |
2603 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2604 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
2605 | Tensor accum; |
2606 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2607 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
2608 | Tensor linear; |
2609 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2610 | ctx, 2, use_exclusive_lock_, sparse, &linear)); |
2611 | OP_REQUIRES( |
2612 | ctx, var.IsInitialized(), |
2613 | errors::FailedPrecondition( |
2614 | "Attempting to use uninitialized variables: " , requested_input(0))); |
2615 | OP_REQUIRES( |
2616 | ctx, accum.IsInitialized(), |
2617 | errors::FailedPrecondition( |
2618 | "Attempting to use uninitialized variables: " , requested_input(1))); |
2619 | OP_REQUIRES( |
2620 | ctx, linear.IsInitialized(), |
2621 | errors::FailedPrecondition( |
2622 | "Attempting to use uninitialized variables: " , requested_input(2))); |
2623 | |
2624 | const Tensor& grad = ctx->input(3); |
2625 | OP_REQUIRES( |
2626 | ctx, var.shape().IsSameSize(accum.shape()), |
2627 | errors::InvalidArgument("var and accum do not have the same shape" , |
2628 | var.shape().DebugString(), " " , |
2629 | accum.shape().DebugString())); |
2630 | OP_REQUIRES( |
2631 | ctx, var.shape().IsSameSize(linear.shape()), |
2632 | errors::InvalidArgument("var and linear do not have the same shape" , |
2633 | var.shape().DebugString(), " " , |
2634 | linear.shape().DebugString())); |
2635 | OP_REQUIRES( |
2636 | ctx, var.shape().IsSameSize(grad.shape()), |
2637 | errors::InvalidArgument("var and grad do not have the same shape" , |
2638 | var.shape().DebugString(), " " , |
2639 | grad.shape().DebugString())); |
2640 | |
2641 | const Tensor& lr = ctx->input(4); |
2642 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
2643 | errors::InvalidArgument("lr is not a scalar: " , |
2644 | lr.shape().DebugString())); |
2645 | const Tensor& l1 = ctx->input(5); |
2646 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()), |
2647 | errors::InvalidArgument("l1 regularization strength is not a " |
2648 | "scalar: " , |
2649 | l1.shape().DebugString())); |
2650 | const Tensor& l2 = ctx->input(6); |
2651 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()), |
2652 | errors::InvalidArgument("l2 regularization strength is not a " |
2653 | "scalar: " , |
2654 | l2.shape().DebugString())); |
2655 | const int lr_power_index = has_l2_shrinkage ? 8 : 7; |
2656 | const Tensor& lr_power = ctx->input(lr_power_index); |
2657 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power.shape()), |
                errors::InvalidArgument("lr_power is not a scalar: ",
2659 | lr_power.shape().DebugString())); |
2660 | |
2661 | const Device& device = ctx->template eigen_device<Device>(); |
2662 | if (has_l2_shrinkage) { |
2663 | const Tensor& l2_shrinkage = ctx->input(7); |
2664 | OP_REQUIRES( |
2665 | ctx, TensorShapeUtils::IsScalar(l2_shrinkage.shape()), |
2666 | errors::InvalidArgument("l2 shrinkage regularization strength " |
2667 | "is not a scalar: " , |
2668 | l2_shrinkage.shape().DebugString())); |
2669 | if (multiply_linear_by_lr_) { |
2670 | functor::ApplyFtrlV2MultiplyLinearByLr<Device, T>()( |
2671 | device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), |
2672 | grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), |
2673 | l2_shrinkage.scalar<T>(), lr_power.scalar<T>()); |
2674 | } else { |
2675 | functor::ApplyFtrlV2<Device, T>()( |
2676 | device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), |
2677 | grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), |
2678 | l2_shrinkage.scalar<T>(), lr_power.scalar<T>()); |
2679 | } |
2680 | } else if (multiply_linear_by_lr_) { |
2681 | functor::ApplyFtrlMultiplyLinearByLr<Device, T>()( |
2682 | device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), |
2683 | grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), |
2684 | lr_power.scalar<T>()); |
2685 | } else { |
2686 | functor::ApplyFtrl<Device, T>()(device, var.flat<T>(), accum.flat<T>(), |
2687 | linear.flat<T>(), grad.flat<T>(), |
2688 | lr.scalar<T>(), l1.scalar<T>(), |
2689 | l2.scalar<T>(), lr_power.scalar<T>()); |
2690 | } |
2691 | |
2692 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
2693 | } |
2694 | |
2695 | private: |
2696 | bool use_exclusive_lock_; |
2697 | bool multiply_linear_by_lr_; |
2698 | }; |
2699 | |
2700 | #define REGISTER_KERNELS(D, T) \ |
2701 | REGISTER_KERNEL_BUILDER( \ |
2702 | Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
2703 | ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>); \ |
2704 | REGISTER_KERNEL_BUILDER( \ |
2705 | Name("ResourceApplyFtrl") \ |
2706 | .HostMemory("var") \ |
2707 | .HostMemory("accum") \ |
2708 | .HostMemory("linear") \ |
2709 | .Device(DEVICE_##D) \ |
2710 | .TypeConstraint<T>("T"), \ |
2711 | ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>); |
2712 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
2713 | |
2714 | TF_CALL_half(REGISTER_CPU_KERNELS); |
2715 | TF_CALL_bfloat16(REGISTER_CPU_KERNELS); |
2716 | TF_CALL_float(REGISTER_CPU_KERNELS); |
2717 | TF_CALL_double(REGISTER_CPU_KERNELS); |
2718 | |
2719 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2720 | // Forward declarations of the functor specializations for GPU. |
2721 | namespace functor { |
2722 | #define DECLARE_GPU_SPEC(T) \ |
2723 | template <> \ |
2724 | void ApplyFtrl<GPUDevice, T>::operator()( \ |
2725 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
2726 | typename TTypes<T>::Flat accum, typename TTypes<T>::Flat linear, \ |
2727 | typename TTypes<T>::ConstFlat grad, typename TTypes<T>::ConstScalar lr, \ |
2728 | typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \ |
2729 | typename TTypes<T>::ConstScalar lr_power); \ |
2730 | extern template struct ApplyFtrl<GPUDevice, T>; |
2731 | DECLARE_GPU_SPEC(Eigen::half); |
2732 | DECLARE_GPU_SPEC(float); |
2733 | DECLARE_GPU_SPEC(double); |
2734 | #undef DECLARE_GPU_SPEC |
2735 | } // namespace functor |
2736 | |
2737 | REGISTER_KERNELS(GPU, Eigen::half); |
2738 | REGISTER_KERNELS(GPU, float); |
2739 | REGISTER_KERNELS(GPU, double); |
2740 | #endif |
2741 | #undef REGISTER_CPU_KERNELS |
2742 | #undef REGISTER_KERNELS |
2743 | |
2744 | #define REGISTER_KERNELS(D, T) \ |
2745 | REGISTER_KERNEL_BUILDER( \ |
2746 | Name("ApplyFtrlV2").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
2747 | ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>); \ |
2748 | REGISTER_KERNEL_BUILDER( \ |
2749 | Name("ResourceApplyFtrlV2") \ |
2750 | .HostMemory("var") \ |
2751 | .HostMemory("accum") \ |
2752 | .HostMemory("linear") \ |
2753 | .Device(DEVICE_##D) \ |
2754 | .TypeConstraint<T>("T"), \ |
2755 | ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>); |
2756 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
2757 | |
2758 | TF_CALL_half(REGISTER_CPU_KERNELS); |
2759 | TF_CALL_bfloat16(REGISTER_CPU_KERNELS); |
2760 | TF_CALL_float(REGISTER_CPU_KERNELS); |
2761 | TF_CALL_double(REGISTER_CPU_KERNELS); |
2762 | |
2763 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2764 | // Forward declarations of the functor specializations for GPU. |
2765 | namespace functor { |
2766 | #define DECLARE_GPU_SPEC(T) \ |
2767 | template <> \ |
2768 | void ApplyFtrlV2<GPUDevice, T>::operator()( \ |
2769 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
2770 | typename TTypes<T>::Flat accum, typename TTypes<T>::Flat linear, \ |
2771 | typename TTypes<T>::ConstFlat grad, typename TTypes<T>::ConstScalar lr, \ |
2772 | typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \ |
2773 | typename TTypes<T>::ConstScalar l2_shrinkage, \ |
2774 | typename TTypes<T>::ConstScalar lr_power); \ |
2775 | extern template struct ApplyFtrlV2<GPUDevice, T>; \ |
2776 | template <> \ |
2777 | void ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T>::operator()( \ |
2778 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
2779 | typename TTypes<T>::Flat accum, typename TTypes<T>::Flat linear, \ |
2780 | typename TTypes<T>::ConstFlat grad, typename TTypes<T>::ConstScalar lr, \ |
2781 | typename TTypes<T>::ConstScalar l1, typename TTypes<T>::ConstScalar l2, \ |
2782 | typename TTypes<T>::ConstScalar l2_shrinkage, \ |
2783 | typename TTypes<T>::ConstScalar lr_power); \ |
2784 | extern template struct ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T>; |
2785 | DECLARE_GPU_SPEC(Eigen::half); |
2786 | DECLARE_GPU_SPEC(float); |
2787 | DECLARE_GPU_SPEC(double); |
2788 | #undef DECLARE_GPU_SPEC |
2789 | } // namespace functor |
2790 | |
2791 | REGISTER_KERNELS(GPU, Eigen::half); |
2792 | REGISTER_KERNELS(GPU, float); |
2793 | REGISTER_KERNELS(GPU, double); |
2794 | #endif |
2795 | #undef REGISTER_CPU_KERNELS |
2796 | #undef REGISTER_KERNELS |
2797 | |
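// The dense kernels above and the sparse kernel below both defer the
// arithmetic to functors declared in training_ops.h. As a rough sketch only
// (not the authoritative definition), the per-element FTRL-Proximal update is:
//   grad_shr  = grad + 2 * l2_shrinkage * var       (FtrlV2 only)
//   accum_new = accum + grad * grad
//   sigma     = (accum_new^(-lr_power) - accum^(-lr_power)) / lr
//   linear   += grad_shr - sigma * var
//   var       = (sign(linear) * l1 - linear) /
//               (accum_new^(-lr_power) / lr + 2 * l2)   if |linear| > l1
//               0                                       otherwise
//   accum     = accum_new
// The multiply_linear_by_lr variants instead keep lr folded into `linear`.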
2798 | template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage> |
2799 | class SparseApplyFtrlOp : public OpKernel { |
2800 | public: |
2801 | explicit SparseApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
2802 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
2803 | OP_REQUIRES_OK( |
2804 | ctx, ctx->GetAttr("multiply_linear_by_lr" , &multiply_linear_by_lr_)); |
2805 | } |
2806 | |
2807 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
2808 | const bool sparse = true; |
2809 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
2810 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
2811 | Tensor var; |
2812 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2813 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
2814 | Tensor accum; |
2815 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2816 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
2817 | Tensor linear; |
2818 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
2819 | ctx, 2, use_exclusive_lock_, sparse, &linear)); |
2820 | OP_REQUIRES( |
2821 | ctx, var.IsInitialized(), |
2822 | errors::FailedPrecondition( |
2823 | "Attempting to use uninitialized variables: " , requested_input(0))); |
2824 | OP_REQUIRES( |
2825 | ctx, accum.IsInitialized(), |
2826 | errors::FailedPrecondition( |
2827 | "Attempting to use uninitialized variables: " , requested_input(1))); |
2828 | OP_REQUIRES( |
2829 | ctx, linear.IsInitialized(), |
2830 | errors::FailedPrecondition( |
2831 | "Attempting to use uninitialized variables: " , requested_input(2))); |
2832 | OP_REQUIRES( |
2833 | ctx, var.shape().IsSameSize(accum.shape()), |
2834 | errors::InvalidArgument("var and accum do not have the same shape" , |
2835 | var.shape().DebugString(), " " , |
2836 | accum.shape().DebugString())); |
2837 | OP_REQUIRES( |
2838 | ctx, var.shape().IsSameSize(linear.shape()), |
2839 | errors::InvalidArgument("var and linear do not have the same shape" , |
2840 | var.shape().DebugString(), " " , |
2841 | linear.shape().DebugString())); |
2842 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
2843 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
2844 | |
2845 | const Tensor& grad = ctx->input(3); |
2846 | const Tensor& indices = ctx->input(4); |
2847 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
2848 | errors::InvalidArgument("indices must be one-dimensional" )); |
2849 | |
    // Note: the range checks on lr, l1, l2, and lr_power below are disabled
    // for non-CPU devices because their values cannot be read directly from
    // the host. The GPU kernel will not crash if these conditions are
    // violated; it will simply produce a bogus result (possibly inf/nan).
2854 | const Tensor& lr = ctx->input(5); |
2855 | OP_REQUIRES( |
2856 | ctx, |
2857 | TensorShapeUtils::IsScalar(lr.shape()) && |
2858 | (!std::is_same<Device, CPUDevice>::value || |
2859 | lr.scalar<T>()() > static_cast<T>(0) || |
2860 | (multiply_linear_by_lr_ && lr.scalar<T>()() >= static_cast<T>(0))), |
2861 | errors::InvalidArgument("lr is not a positive scalar (or zero if " |
2862 | "multiply_linear_by_lr is set): " , |
2863 | lr.shape().DebugString())); |
2864 | |
2865 | const Tensor& l1 = ctx->input(6); |
2866 | OP_REQUIRES(ctx, |
2867 | TensorShapeUtils::IsScalar(l1.shape()) && |
2868 | (!std::is_same<Device, CPUDevice>::value || |
2869 | l1.scalar<T>()() >= static_cast<T>(0)), |
2870 | errors::InvalidArgument("l1 regularization strength is not a " |
2871 | "non-negative scalar: " , |
2872 | l1.shape().DebugString())); |
2873 | const Tensor& l2 = ctx->input(7); |
2874 | OP_REQUIRES(ctx, |
2875 | TensorShapeUtils::IsScalar(l2.shape()) && |
2876 | (!std::is_same<Device, CPUDevice>::value || |
2877 | l2.scalar<T>()() >= static_cast<T>(0)), |
2878 | errors::InvalidArgument("l2 regularization strength is not a " |
2879 | "non-negative scalar: " , |
2880 | l2.shape().DebugString())); |
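    // Sparse input layout: var(0), accum(1), linear(2), grad(3), indices(4),
    // lr(5), l1(6), l2(7); SparseApplyFtrlV2 additionally takes
    // l2_shrinkage(8), which shifts lr_power from input 8 to input 9.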
2881 | const int lr_power_index = has_l2_shrinkage ? 9 : 8; |
2882 | const Tensor& lr_power = ctx->input(lr_power_index); |
2883 | OP_REQUIRES(ctx, |
2884 | TensorShapeUtils::IsScalar(lr_power.shape()) && |
2885 | (!std::is_same<Device, CPUDevice>::value || |
2886 | lr_power.scalar<T>()() <= static_cast<T>(0)), |
2887 | errors::InvalidArgument("lr_power is not a " |
2888 | "non-positive scalar: " , |
2889 | lr_power.shape().DebugString())); |
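    // var and grad must agree on every dimension after the first; inner_dim is
    // the flattened size of those trailing dimensions (the per-row slice size).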
2890 | int64_t inner_dim = 1; |
2891 | for (int d = 1; d < var.dims(); d++) { |
2892 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
2893 | errors::InvalidArgument(strings::StrCat( |
2894 | "var and grad must match in dimension " , d))); |
2895 | inner_dim *= grad.dim_size(d); |
2896 | } |
2897 | const Tindex N = indices.dim_size(0); |
2898 | OP_REQUIRES( |
2899 | ctx, grad.dim_size(0) == N, |
2900 | errors::InvalidArgument( |
2901 | "grad must be the same size as indices in the first dimension." )); |
2902 | |
2903 | OP_REQUIRES(ctx, inner_dim > 0, |
2904 | errors::InvalidArgument( |
2905 | "Inner dimension should be greater than zero." )); |
2906 | |
2907 | const Tensor* l2_shrinkage; |
2908 | if (has_l2_shrinkage) { |
2909 | l2_shrinkage = &ctx->input(8); |
2910 | OP_REQUIRES( |
2911 | ctx, |
2912 | TensorShapeUtils::IsScalar(l2_shrinkage->shape()) && |
2913 | (!std::is_same<Device, CPUDevice>::value || |
2914 | l2_shrinkage->scalar<T>()() >= static_cast<T>(0)), |
2915 | errors::InvalidArgument("l2 shrinkage regularization strength " |
2916 | "is not a non-negative scalar: " , |
2917 | l2_shrinkage->shape().DebugString())); |
2918 | } |
2919 | |
2920 | const Device& device = ctx->template eigen_device<Device>(); |
2921 | auto indices_vec = indices.vec<Tindex>(); |
2922 | OP_REQUIRES_OK( |
2923 | ctx, functor::SparseApplyFtrl<Device, T, Tindex, has_l2_shrinkage>()( |
2924 | device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(), |
2925 | linear.flat_outer_dims<T>(), lr.scalar<T>(), l1.scalar<T>(), |
2926 | l2.scalar<T>(), |
                 // Note: when has_l2_shrinkage is false, l2 is passed here
                 // only as a placeholder for the l2_shrinkage argument; it is
                 // not used in that case.
2929 | has_l2_shrinkage ? l2_shrinkage->scalar<T>() : l2.scalar<T>(), |
2930 | lr_power.scalar<T>(), grad.flat_outer_dims<T>(), indices_vec, |
2931 | inner_dim, multiply_linear_by_lr_)); |
2932 | |
2933 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
2934 | } |
2935 | |
2936 | private: |
2937 | bool use_exclusive_lock_; |
2938 | bool multiply_linear_by_lr_; |
2939 | }; |
2940 | |
2941 | #define REGISTER_KERNELS(D, T, Tindices) \ |
2942 | REGISTER_KERNEL_BUILDER( \ |
2943 | Name("SparseApplyFtrl") \ |
2944 | .Device(DEVICE_##D) \ |
2945 | .TypeConstraint<T>("T") \ |
2946 | .TypeConstraint<Tindices>("Tindices"), \ |
2947 | SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/false>); \ |
2948 | REGISTER_KERNEL_BUILDER( \ |
2949 | Name("ResourceSparseApplyFtrl") \ |
2950 | .Device(DEVICE_##D) \ |
2951 | .TypeConstraint<T>("T") \ |
2952 | .TypeConstraint<Tindices>("Tindices"), \ |
2953 | SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/false>); |
2954 | #define REGISTER_CPU_KERNELS(T) \ |
2955 | REGISTER_KERNELS(CPU, T, int32); \ |
2956 | REGISTER_KERNELS(CPU, T, int64_t); |
2957 | |
2958 | TF_CALL_half(REGISTER_CPU_KERNELS); |
2959 | TF_CALL_bfloat16(REGISTER_CPU_KERNELS); |
2960 | TF_CALL_float(REGISTER_CPU_KERNELS); |
2961 | TF_CALL_double(REGISTER_CPU_KERNELS); |
2962 | |
2963 | #undef REGISTER_CPU_KERNELS |
2964 | |
2965 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2966 | // Forward declarations of the functor specializations for GPU. |
2967 | namespace functor { |
2968 | #define DECLARE_GPU_SPEC(T, Tindex) \ |
2969 | template <> \ |
2970 | Status SparseApplyFtrl<GPUDevice, T, Tindex, /*has_l2_shrinkage=*/false>:: \ |
2971 | operator()( \ |
2972 | const GPUDevice& d, typename TTypes<T>::Matrix var, \ |
2973 | typename TTypes<T>::Matrix accum, typename TTypes<T>::Matrix linear, \ |
2974 | typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar l1, \ |
2975 | typename TTypes<T>::ConstScalar l2, \ |
2976 | typename TTypes<T>::ConstScalar l2_shrinkage, \ |
2977 | typename TTypes<T>::ConstScalar lr_power, \ |
2978 | typename TTypes<T>::ConstMatrix grad, \ |
2979 | typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \ |
2980 | bool multiply_linear_by_lr); \ |
2981 | extern template struct SparseApplyFtrl<GPUDevice, T, Tindex, \ |
2982 | /*has_l2_shrinkage=*/false>; |
2983 | DECLARE_GPU_SPEC(Eigen::half, int32); |
2984 | DECLARE_GPU_SPEC(Eigen::half, int64_t); |
2985 | DECLARE_GPU_SPEC(float, int32); |
2986 | DECLARE_GPU_SPEC(float, int64_t); |
2987 | DECLARE_GPU_SPEC(double, int32); |
2988 | DECLARE_GPU_SPEC(double, int64_t); |
2989 | #undef DECLARE_GPU_SPEC |
2990 | } // namespace functor |
2991 | |
2992 | REGISTER_KERNELS(GPU, Eigen::half, int32); |
2993 | REGISTER_KERNELS(GPU, Eigen::half, int64_t); |
2994 | REGISTER_KERNELS(GPU, float, int32); |
2995 | REGISTER_KERNELS(GPU, float, int64_t); |
2996 | REGISTER_KERNELS(GPU, double, int32); |
2997 | REGISTER_KERNELS(GPU, double, int64_t); |
2998 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2999 | #undef REGISTER_KERNELS |
3000 | |
3001 | #define REGISTER_KERNELS(D, T, Tindices) \ |
3002 | REGISTER_KERNEL_BUILDER( \ |
3003 | Name("SparseApplyFtrlV2") \ |
3004 | .Device(DEVICE_##D) \ |
3005 | .TypeConstraint<T>("T") \ |
3006 | .TypeConstraint<Tindices>("Tindices"), \ |
3007 | SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/true>); \ |
3008 | REGISTER_KERNEL_BUILDER( \ |
3009 | Name("ResourceSparseApplyFtrlV2") \ |
3010 | .Device(DEVICE_##D) \ |
3011 | .TypeConstraint<T>("T") \ |
3012 | .TypeConstraint<Tindices>("Tindices"), \ |
3013 | SparseApplyFtrlOp<D##Device, T, Tindices, /*has_l2_shrinkage=*/true>); |
3014 | #define REGISTER_CPU_KERNELS(T) \ |
3015 | REGISTER_KERNELS(CPU, T, int32); \ |
3016 | REGISTER_KERNELS(CPU, T, int64_t); |
3017 | |
3018 | TF_CALL_half(REGISTER_CPU_KERNELS); |
3019 | TF_CALL_bfloat16(REGISTER_CPU_KERNELS); |
3020 | TF_CALL_float(REGISTER_CPU_KERNELS); |
3021 | TF_CALL_double(REGISTER_CPU_KERNELS); |
3022 | |
3023 | #undef REGISTER_CPU_KERNELS |
3024 | |
3025 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3026 | // Forward declarations of the functor specializations for GPU. |
3027 | namespace functor { |
3028 | #define DECLARE_GPU_SPEC(T, Tindex) \ |
3029 | template <> \ |
3030 | Status SparseApplyFtrl<GPUDevice, T, Tindex, /*has_l2_shrinkage=*/true>:: \ |
3031 | operator()( \ |
3032 | const GPUDevice& d, typename TTypes<T>::Matrix var, \ |
3033 | typename TTypes<T>::Matrix accum, typename TTypes<T>::Matrix linear, \ |
3034 | typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar l1, \ |
3035 | typename TTypes<T>::ConstScalar l2, \ |
3036 | typename TTypes<T>::ConstScalar l2_shrinkage, \ |
3037 | typename TTypes<T>::ConstScalar lr_power, \ |
3038 | typename TTypes<T>::ConstMatrix grad, \ |
3039 | typename TTypes<Tindex>::ConstVec indices, int64_t inner_dim, \ |
3040 | bool multiply_linear_by_lr); \ |
3041 | extern template struct SparseApplyFtrl<GPUDevice, T, Tindex, \ |
3042 | /*has_l2_shrinkage=*/true>; |
3043 | DECLARE_GPU_SPEC(Eigen::half, int32); |
3044 | DECLARE_GPU_SPEC(Eigen::half, int64_t); |
3045 | DECLARE_GPU_SPEC(float, int32); |
3046 | DECLARE_GPU_SPEC(float, int64_t); |
3047 | DECLARE_GPU_SPEC(double, int32); |
3048 | DECLARE_GPU_SPEC(double, int64_t); |
3049 | #undef DECLARE_GPU_SPEC |
3050 | } // namespace functor |
3051 | |
3052 | REGISTER_KERNELS(GPU, Eigen::half, int32); |
3053 | REGISTER_KERNELS(GPU, Eigen::half, int64_t); |
3054 | REGISTER_KERNELS(GPU, float, int32); |
3055 | REGISTER_KERNELS(GPU, float, int64_t); |
3056 | REGISTER_KERNELS(GPU, double, int32); |
3057 | REGISTER_KERNELS(GPU, double, int64_t); |
3058 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3059 | #undef REGISTER_KERNELS |
3060 | |
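// Sketch of the per-element update applied by the momentum functors (see
// training_ops.h); shown for orientation only:
//   accum = accum * momentum + grad
//   var  -= lr * accum                              (use_nesterov = false)
//   var  -= lr * grad + lr * momentum * accum       (use_nesterov = true)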
3061 | template <typename Device, typename T> |
3062 | class ApplyMomentumOp : public OpKernel { |
3063 | public: |
3064 | explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
3065 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3066 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov" , &use_nesterov_)); |
3067 | } |
3068 | |
3069 | void Compute(OpKernelContext* ctx) override { |
3070 | const bool sparse = false; |
3071 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
3072 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
3073 | |
3074 | Tensor var; |
3075 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3076 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3077 | Tensor accum; |
3078 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3079 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
3080 | OP_REQUIRES( |
3081 | ctx, var.IsInitialized(), |
3082 | errors::FailedPrecondition( |
3083 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3084 | OP_REQUIRES( |
3085 | ctx, accum.IsInitialized(), |
3086 | errors::FailedPrecondition( |
3087 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3088 | const Tensor& lr = ctx->input(2); |
3089 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3090 | errors::InvalidArgument("lr is not a scalar: " , |
3091 | lr.shape().DebugString())); |
3092 | const Tensor& grad = ctx->input(3); |
3093 | OP_REQUIRES( |
3094 | ctx, var.shape().IsSameSize(accum.shape()), |
3095 | errors::InvalidArgument("var and accum do not have the same shape" , |
3096 | var.shape().DebugString(), " " , |
3097 | accum.shape().DebugString())); |
3098 | OP_REQUIRES( |
3099 | ctx, var.shape().IsSameSize(grad.shape()), |
3100 | errors::InvalidArgument("var and grad do not have the same shape" , |
3101 | var.shape().DebugString(), " " , |
3102 | grad.shape().DebugString())); |
3103 | |
3104 | const Tensor& momentum = ctx->input(4); |
3105 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
3106 | errors::InvalidArgument("momentum is not a scalar: " , |
3107 | momentum.shape().DebugString())); |
3108 | |
3109 | const Device& device = ctx->template eigen_device<Device>(); |
3110 | functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(), |
3111 | lr.scalar<T>(), grad.flat<T>(), |
3112 | momentum.scalar<T>(), use_nesterov_); |
3113 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
3114 | } |
3115 | |
3116 | private: |
3117 | bool use_exclusive_lock_; |
3118 | bool use_nesterov_; |
3119 | }; |
3120 | |
3121 | #define REGISTER_KERNELS(D, T) \ |
3122 | REGISTER_KERNEL_BUILDER( \ |
3123 | Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
3124 | ApplyMomentumOp<D##Device, T>); \ |
3125 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyMomentum") \ |
3126 | .Device(DEVICE_##D) \ |
3127 | .HostMemory("var") \ |
3128 | .HostMemory("accum") \ |
3129 | .TypeConstraint<T>("T"), \ |
3130 | ApplyMomentumOp<D##Device, T>); |
3131 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
3132 | |
3133 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
3134 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
3135 | |
3136 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3137 | // Forward declarations of the functor specializations for GPU. |
3138 | namespace functor { |
3139 | #define DECLARE_GPU_SPEC(T) \ |
3140 | template <> \ |
3141 | void ApplyMomentum<GPUDevice, T>::operator()( \ |
3142 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
3143 | typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \ |
3144 | typename TTypes<T>::ConstFlat grad, \ |
3145 | typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \ |
3146 | extern template struct ApplyMomentum<GPUDevice, T>; |
3147 | DECLARE_GPU_SPEC(Eigen::half); |
3148 | DECLARE_GPU_SPEC(float); |
3149 | DECLARE_GPU_SPEC(double); |
3150 | DECLARE_GPU_SPEC(complex64); |
3151 | DECLARE_GPU_SPEC(complex128); |
3152 | #undef DECLARE_GPU_SPEC |
3153 | } // namespace functor |
3154 | |
3155 | REGISTER_KERNELS(GPU, Eigen::half); |
3156 | REGISTER_KERNELS(GPU, float); |
3157 | REGISTER_KERNELS(GPU, double); |
3158 | REGISTER_KERNELS(GPU, complex64); |
3159 | REGISTER_KERNELS(GPU, complex128); |
3160 | #endif |
3161 | #undef REGISTER_CPU_KERNELS |
3162 | #undef REGISTER_KERNELS |
3163 | |
// Note: this op runs on the CPU only.
3165 | template <typename T, typename Tindex> |
3166 | class SparseApplyMomentumOp : public OpKernel { |
3167 | public: |
3168 | explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
3169 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3170 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov" , &use_nesterov_)); |
3171 | } |
3172 | |
3173 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
3174 | const bool sparse = true; |
3175 | auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>( |
3176 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
3177 | |
3178 | Tensor var; |
3179 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
3180 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3181 | Tensor accum; |
3182 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
3183 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
3184 | OP_REQUIRES( |
3185 | ctx, var.IsInitialized(), |
3186 | errors::FailedPrecondition( |
3187 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3188 | OP_REQUIRES( |
3189 | ctx, accum.IsInitialized(), |
3190 | errors::FailedPrecondition( |
3191 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3192 | OP_REQUIRES( |
3193 | ctx, var.shape().IsSameSize(accum.shape()), |
3194 | errors::InvalidArgument("var and accum do not have the same shape" , |
3195 | var.shape().DebugString(), " " , |
3196 | accum.shape().DebugString())); |
3197 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
3198 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
3199 | |
3200 | const Tensor& lr = ctx->input(2); |
3201 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3202 | errors::InvalidArgument("lr is not a scalar : " , |
3203 | lr.shape().DebugString())); |
3204 | const Tensor& grad = ctx->input(3); |
3205 | const Tensor& indices = ctx->input(4); |
3206 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
3207 | errors::InvalidArgument("indices must be one-dimensional" )); |
3208 | |
3209 | for (int d = 1; d < var.dims(); d++) { |
3210 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
3211 | errors::InvalidArgument(strings::StrCat( |
3212 | "var and grad must match in dimension " , d))); |
3213 | } |
3214 | const Tindex N = indices.dim_size(0); |
3215 | OP_REQUIRES( |
3216 | ctx, grad.dim_size(0) == N, |
3217 | errors::InvalidArgument( |
3218 | "grad must be the same size as indices in the first dimension." )); |
3219 | |
3220 | const Tensor& momentum = ctx->input(5); |
3221 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
3222 | errors::InvalidArgument("momentum is not a scalar: " , |
3223 | momentum.shape().DebugString())); |
3224 | |
3225 | if (N > 0) { |
3226 | const Tindex first_dim_size = var.dim_size(0); |
3227 | auto indices_vec = indices.vec<Tindex>(); |
3228 | auto var_flat = var.flat_outer_dims<T>(); |
3229 | auto accum_flat = accum.flat_outer_dims<T>(); |
3230 | auto grad_flat = grad.flat_outer_dims<T>(); |
3231 | T lr_scalar = lr.scalar<T>()(); |
3232 | T momentum_scalar = momentum.scalar<T>()(); |
3233 | |
3234 | for (Tindex i = 0; i < N; i++) { |
3235 | const Tindex index = internal::SubtleMustCopy(indices_vec(i)); |
3236 | OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), |
3237 | errors::InvalidArgument( |
3238 | strings::StrCat("Index " , index, " at offset " , i, |
3239 | " in indices is out of range" ))); |
3240 | auto a = accum_flat.template chip<0>(index); |
3241 | auto g = grad_flat.template chip<0>(i); |
3242 | auto v = var_flat.template chip<0>(index); |
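        // Same update as the dense kernel, restricted to the selected row:
        // accum = momentum * accum + grad, followed by either the plain or
        // the Nesterov-corrected step on var.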
3243 | a = a * a.constant(momentum_scalar) + g; |
3244 | if (use_nesterov_) { |
3245 | v -= g.constant(lr_scalar) * g + |
3246 | a.constant(lr_scalar) * a.constant(momentum_scalar) * a; |
3247 | } else { |
3248 | v -= a.constant(lr_scalar) * a; |
3249 | } |
3250 | } |
3251 | } |
3252 | |
3253 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
3254 | } |
3255 | |
3256 | private: |
3257 | bool use_exclusive_lock_; |
3258 | bool use_nesterov_; |
3259 | }; |
3260 | |
3261 | #define REGISTER_KERNELS(T, Tindices) \ |
3262 | REGISTER_KERNEL_BUILDER(Name("SparseApplyMomentum") \ |
3263 | .Device(DEVICE_CPU) \ |
3264 | .TypeConstraint<T>("T") \ |
3265 | .TypeConstraint<Tindices>("Tindices"), \ |
3266 | SparseApplyMomentumOp<T, Tindices>); \ |
3267 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyMomentum") \ |
3268 | .Device(DEVICE_CPU) \ |
3269 | .TypeConstraint<T>("T") \ |
3270 | .TypeConstraint<Tindices>("Tindices"), \ |
3271 | SparseApplyMomentumOp<T, Tindices>); |
3272 | #define REGISTER_CPU_KERNELS(T) \ |
3273 | REGISTER_KERNELS(T, int32); \ |
3274 | REGISTER_KERNELS(T, int64_t); |
3275 | |
3276 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
3277 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
3278 | |
3279 | #undef REGISTER_CPU_KERNELS |
3280 | #undef REGISTER_KERNELS |
3281 | |
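// The Keras flavour of momentum folds the learning rate into the accumulator.
// A hedged sketch of the functor's per-element update:
//   accum = accum * momentum - lr * grad
//   var  += accum                                   (use_nesterov = false)
//   var  += accum * momentum - lr * grad            (use_nesterov = true)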
3282 | template <typename Device, typename T> |
3283 | class ApplyKerasMomentumOp : public OpKernel { |
3284 | public: |
3285 | explicit ApplyKerasMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
3286 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3287 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov" , &use_nesterov_)); |
3288 | } |
3289 | |
3290 | void Compute(OpKernelContext* ctx) override { |
3291 | const bool sparse = false; |
3292 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
3293 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
3294 | |
3295 | Tensor var; |
3296 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3297 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3298 | Tensor accum; |
3299 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3300 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
3301 | OP_REQUIRES( |
3302 | ctx, var.IsInitialized(), |
3303 | errors::FailedPrecondition( |
3304 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3305 | OP_REQUIRES( |
3306 | ctx, accum.IsInitialized(), |
3307 | errors::FailedPrecondition( |
3308 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3309 | const Tensor& lr = ctx->input(2); |
3310 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3311 | errors::InvalidArgument("lr is not a scalar: " , |
3312 | lr.shape().DebugString())); |
3313 | const Tensor& grad = ctx->input(3); |
3314 | OP_REQUIRES( |
3315 | ctx, var.shape().IsSameSize(accum.shape()), |
3316 | errors::InvalidArgument("var and accum do not have the same shape" , |
3317 | var.shape().DebugString(), " " , |
3318 | accum.shape().DebugString())); |
3319 | OP_REQUIRES( |
3320 | ctx, var.shape().IsSameSize(grad.shape()), |
3321 | errors::InvalidArgument("var and grad do not have the same shape" , |
3322 | var.shape().DebugString(), " " , |
3323 | grad.shape().DebugString())); |
3324 | |
3325 | const Tensor& momentum = ctx->input(4); |
3326 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
3327 | errors::InvalidArgument("momentum is not a scalar: " , |
3328 | momentum.shape().DebugString())); |
3329 | |
3330 | const Device& device = ctx->template eigen_device<Device>(); |
3331 | functor::ApplyKerasMomentum<Device, T>()( |
3332 | device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), grad.flat<T>(), |
3333 | momentum.scalar<T>(), use_nesterov_); |
3334 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
3335 | } |
3336 | |
3337 | private: |
3338 | bool use_exclusive_lock_; |
3339 | bool use_nesterov_; |
3340 | }; |
3341 | |
3342 | #define REGISTER_KERNELS(D, T) \ |
3343 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyKerasMomentum") \ |
3344 | .Device(DEVICE_##D) \ |
3345 | .HostMemory("var") \ |
3346 | .HostMemory("accum") \ |
3347 | .TypeConstraint<T>("T"), \ |
3348 | ApplyKerasMomentumOp<D##Device, T>); |
3349 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
3350 | |
3351 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
3352 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
3353 | |
3354 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3355 | // Forward declarations of the functor specializations for GPU. |
3356 | namespace functor { |
3357 | #define DECLARE_GPU_SPEC(T) \ |
3358 | template <> \ |
3359 | void ApplyKerasMomentum<GPUDevice, T>::operator()( \ |
3360 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
3361 | typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \ |
3362 | typename TTypes<T>::ConstFlat grad, \ |
3363 | typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \ |
3364 | extern template struct ApplyKerasMomentum<GPUDevice, T>; |
3365 | DECLARE_GPU_SPEC(Eigen::half); |
3366 | DECLARE_GPU_SPEC(float); |
3367 | DECLARE_GPU_SPEC(double); |
3368 | DECLARE_GPU_SPEC(complex64); |
3369 | DECLARE_GPU_SPEC(complex128); |
3370 | #undef DECLARE_GPU_SPEC |
3371 | } // namespace functor |
3372 | |
3373 | REGISTER_KERNELS(GPU, Eigen::half); |
3374 | REGISTER_KERNELS(GPU, float); |
3375 | REGISTER_KERNELS(GPU, double); |
3376 | REGISTER_KERNELS(GPU, complex64); |
3377 | REGISTER_KERNELS(GPU, complex128); |
3378 | #endif |
3379 | #undef REGISTER_CPU_KERNELS |
3380 | #undef REGISTER_KERNELS |
3381 | |
// Note: unlike SparseApplyMomentumOp above, this op is templated on Device and
// is also registered for GPU below.
3383 | template <typename T, typename Device, typename Tindex> |
3384 | class SparseApplyKerasMomentumOp : public OpKernel { |
3385 | public: |
3386 | explicit SparseApplyKerasMomentumOp(OpKernelConstruction* ctx) |
3387 | : OpKernel(ctx) { |
3388 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3389 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov" , &use_nesterov_)); |
3390 | } |
3391 | |
3392 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
3393 | const bool sparse = true; |
3394 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
3395 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
3396 | |
3397 | Tensor var; |
3398 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3399 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3400 | Tensor accum; |
3401 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3402 | ctx, 1, use_exclusive_lock_, sparse, &accum)); |
3403 | OP_REQUIRES( |
3404 | ctx, var.IsInitialized(), |
3405 | errors::FailedPrecondition( |
3406 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3407 | OP_REQUIRES( |
3408 | ctx, accum.IsInitialized(), |
3409 | errors::FailedPrecondition( |
3410 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3411 | OP_REQUIRES( |
3412 | ctx, var.shape().IsSameSize(accum.shape()), |
3413 | errors::InvalidArgument("var and accum do not have the same shape" , |
3414 | var.shape().DebugString(), " " , |
3415 | accum.shape().DebugString())); |
3416 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
3417 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
3418 | |
3419 | const Tensor& lr = ctx->input(2); |
3420 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3421 | errors::InvalidArgument("lr is not a scalar : " , |
3422 | lr.shape().DebugString())); |
3423 | const Tensor& grad = ctx->input(3); |
3424 | const Tensor& indices = ctx->input(4); |
3425 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
3426 | errors::InvalidArgument("indices must be one-dimensional" )); |
3427 | |
3428 | for (int d = 1; d < var.dims(); d++) { |
3429 | OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), |
3430 | errors::InvalidArgument(strings::StrCat( |
3431 | "var and grad must match in dimension " , d))); |
3432 | } |
3433 | const Tindex N = indices.dim_size(0); |
3434 | OP_REQUIRES( |
3435 | ctx, grad.dim_size(0) == N, |
3436 | errors::InvalidArgument( |
3437 | "grad must be the same size as indices in the first dimension." )); |
3438 | |
3439 | const Tensor& momentum = ctx->input(5); |
3440 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
3441 | errors::InvalidArgument("momentum is not a scalar: " , |
3442 | momentum.shape().DebugString())); |
3443 | |
3444 | const Device& device = ctx->template eigen_device<Device>(); |
3445 | auto indices_flat = indices.flat<Tindex>(); |
3446 | const Tindex bad_i = functor::SparseApplyKerasMomentum<Device, T, Tindex>()( |
3447 | device, var.flat_outer_dims<T>(), accum.flat_outer_dims<T>(), |
3448 | lr.scalar<T>(), grad.flat_outer_dims<T>(), indices_flat, |
3449 | momentum.scalar<T>(), use_nesterov_); |
3450 | OP_REQUIRES( |
3451 | ctx, bad_i < 0, |
3452 | errors::InvalidArgument( |
3453 | "indices" , SliceDebugString(indices.shape(), bad_i), " = " , |
3454 | indices_flat(bad_i), " is not in [0, " , var.dim_size(0), ")" )); |
3455 | |
3456 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
3457 | } |
3458 | |
3459 | private: |
3460 | bool use_exclusive_lock_; |
3461 | bool use_nesterov_; |
3462 | }; |
3463 | |
3464 | #define REGISTER_KERNELS(T, D, Tindices) \ |
3465 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyKerasMomentum") \ |
3466 | .Device(DEVICE_##D) \ |
3467 | .TypeConstraint<T>("T") \ |
3468 | .TypeConstraint<Tindices>("Tindices"), \ |
3469 | SparseApplyKerasMomentumOp<T, D##Device, Tindices>); |
3470 | #define REGISTER_CPU_KERNELS(T) \ |
3471 | REGISTER_KERNELS(T, CPU, int32); \ |
3472 | REGISTER_KERNELS(T, CPU, int64_t); |
3473 | |
3474 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
3475 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
3476 | |
3477 | #undef REGISTER_CPU_KERNELS |
3478 | |
3479 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3480 | // Forward declarations of the functor specializations for GPU. |
3481 | namespace functor { |
3482 | #define DECLARE_GPU_SPEC(T, Tindex) \ |
3483 | template <> \ |
3484 | Tindex SparseApplyKerasMomentum<GPUDevice, T, Tindex>::operator()( \ |
3485 | const GPUDevice& d, typename TTypes<T>::Matrix var, \ |
3486 | typename TTypes<T>::Matrix accum, typename TTypes<T>::ConstScalar lr, \ |
3487 | typename TTypes<T>::ConstMatrix grad, \ |
3488 | typename TTypes<Tindex>::ConstFlat indices, \ |
3489 | typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \ |
3490 | extern template struct SparseApplyKerasMomentum<GPUDevice, T, Tindex>; |
3491 | DECLARE_GPU_SPEC(Eigen::half, int32); |
3492 | DECLARE_GPU_SPEC(Eigen::half, int64_t); |
3493 | DECLARE_GPU_SPEC(float, int32); |
3494 | DECLARE_GPU_SPEC(float, int64_t); |
3495 | DECLARE_GPU_SPEC(double, int32); |
3496 | DECLARE_GPU_SPEC(double, int64_t); |
3497 | DECLARE_GPU_SPEC(complex64, int32); |
3498 | DECLARE_GPU_SPEC(complex64, int64_t); |
3499 | DECLARE_GPU_SPEC(complex128, int32); |
3500 | DECLARE_GPU_SPEC(complex128, int64_t); |
3501 | #undef DECLARE_GPU_SPEC |
3502 | } // namespace functor |
3503 | |
3504 | #define REGISTER_GPU_KERNELS(T) \ |
3505 | REGISTER_KERNELS(T, GPU, int32); \ |
3506 | REGISTER_KERNELS(T, GPU, int64_t); |
3507 | |
3508 | REGISTER_GPU_KERNELS(Eigen::half); |
3509 | REGISTER_GPU_KERNELS(float); |
3510 | REGISTER_GPU_KERNELS(double); |
3511 | REGISTER_GPU_KERNELS(complex64); |
3512 | REGISTER_GPU_KERNELS(complex128); |
3513 | #undef REGISTER_GPU_KERNELS |
3514 | #endif |
3515 | #undef REGISTER_KERNELS |
3516 | |
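// Sketch of the Adam update applied by the functor, with the bias correction
// folded into the step size (shown for orientation, not as the authoritative
// definition):
//   alpha = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//   m    += (grad - m) * (1 - beta1)
//   v    += (grad * grad - v) * (1 - beta2)
//   var  -= m * alpha / (sqrt(v) + epsilon)                     (plain)
//   var  -= (m * beta1 + grad * (1 - beta1)) * alpha /
//           (sqrt(v) + epsilon)                                 (use_nesterov)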
3517 | template <typename Device, typename T> |
3518 | class ApplyAdamOp : public OpKernel { |
3519 | public: |
3520 | explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
3521 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3522 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov" , &use_nesterov_)); |
3523 | } |
3524 | |
3525 | void Compute(OpKernelContext* ctx) override { |
3526 | const bool sparse = false; |
3527 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
3528 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
3529 | |
3530 | Tensor var; |
3531 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3532 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3533 | Tensor m; |
3534 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3535 | ctx, 1, use_exclusive_lock_, sparse, &m)); |
3536 | Tensor v; |
3537 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3538 | ctx, 2, use_exclusive_lock_, sparse, &v)); |
3539 | OP_REQUIRES( |
3540 | ctx, var.IsInitialized(), |
3541 | errors::FailedPrecondition( |
3542 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3543 | OP_REQUIRES( |
3544 | ctx, m.IsInitialized(), |
3545 | errors::FailedPrecondition( |
3546 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3547 | OP_REQUIRES( |
3548 | ctx, v.IsInitialized(), |
3549 | errors::FailedPrecondition( |
3550 | "Attempting to use uninitialized variables: " , requested_input(2))); |
3551 | |
3552 | const Tensor& beta1_power = ctx->input(3); |
3553 | const Tensor& beta2_power = ctx->input(4); |
3554 | const Tensor& lr = ctx->input(5); |
3555 | const Tensor& beta1 = ctx->input(6); |
3556 | const Tensor& beta2 = ctx->input(7); |
3557 | const Tensor& epsilon = ctx->input(8); |
3558 | |
3559 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), |
3560 | errors::InvalidArgument("beta1_power is not a scalar: " , |
3561 | beta1_power.shape().DebugString())); |
3562 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()), |
3563 | errors::InvalidArgument("beta2_power is not a scalar: " , |
3564 | beta2_power.shape().DebugString())); |
3565 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3566 | errors::InvalidArgument("lr is not a scalar : " , |
3567 | lr.shape().DebugString())); |
3568 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), |
3569 | errors::InvalidArgument("beta1 is not a scalar: " , |
3570 | beta1.shape().DebugString())); |
3571 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), |
3572 | errors::InvalidArgument("beta2 is not a scalar: " , |
3573 | beta2.shape().DebugString())); |
3574 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
3575 | errors::InvalidArgument("epsilon is not a scalar: " , |
3576 | epsilon.shape().DebugString())); |
3577 | |
3578 | const Tensor& grad = ctx->input(9); |
3579 | OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), |
3580 | errors::InvalidArgument("var and m do not have the same shape" , |
3581 | var.shape().DebugString(), " " , |
3582 | m.shape().DebugString())); |
3583 | OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), |
3584 | errors::InvalidArgument("var and v do not have the same shape" , |
3585 | var.shape().DebugString(), " " , |
3586 | v.shape().DebugString())); |
3587 | OP_REQUIRES( |
3588 | ctx, var.shape().IsSameSize(grad.shape()), |
3589 | errors::InvalidArgument("var and grad do not have the same shape" , |
3590 | var.shape().DebugString(), " " , |
3591 | grad.shape().DebugString())); |
3592 | |
3593 | const Device& device = ctx->template eigen_device<Device>(); |
3594 | functor::ApplyAdam<Device, T>()( |
3595 | device, var.flat<T>(), m.flat<T>(), v.flat<T>(), |
3596 | beta1_power.scalar<T>(), beta2_power.scalar<T>(), lr.scalar<T>(), |
3597 | beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(), |
3598 | grad.flat<T>(), use_nesterov_); |
3599 | |
3600 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
3601 | } |
3602 | |
3603 | private: |
3604 | bool use_exclusive_lock_; |
3605 | bool use_nesterov_; |
3606 | }; |
3607 | |
3608 | #define REGISTER_KERNELS(D, T) \ |
3609 | REGISTER_KERNEL_BUILDER( \ |
3610 | Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
3611 | ApplyAdamOp<D##Device, T>); \ |
3612 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdam") \ |
3613 | .HostMemory("var") \ |
3614 | .HostMemory("m") \ |
3615 | .HostMemory("v") \ |
3616 | .Device(DEVICE_##D) \ |
3617 | .TypeConstraint<T>("T"), \ |
3618 | ApplyAdamOp<D##Device, T>); |
3619 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
3620 | |
3621 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
3622 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
3623 | |
3624 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3625 | // Forward declarations of the functor specializations for GPU. |
3626 | namespace functor { |
3627 | #define DECLARE_GPU_SPEC(T) \ |
3628 | template <> \ |
3629 | void ApplyAdam<GPUDevice, T>::operator()( \ |
3630 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
3631 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \ |
3632 | typename TTypes<T>::ConstScalar beta1_power, \ |
3633 | typename TTypes<T>::ConstScalar beta2_power, \ |
3634 | typename TTypes<T>::ConstScalar lr, \ |
3635 | typename TTypes<T>::ConstScalar beta1, \ |
3636 | typename TTypes<T>::ConstScalar beta2, \ |
3637 | typename TTypes<T>::ConstScalar epsilon, \ |
3638 | typename TTypes<T>::ConstFlat grad, bool use_nesterov); \ |
3639 | extern template struct ApplyAdam<GPUDevice, T>; |
3640 | DECLARE_GPU_SPEC(Eigen::half); |
3641 | DECLARE_GPU_SPEC(float); |
3642 | DECLARE_GPU_SPEC(double); |
3643 | DECLARE_GPU_SPEC(complex64); |
3644 | DECLARE_GPU_SPEC(complex128); |
3645 | #undef DECLARE_GPU_SPEC |
3646 | } // namespace functor |
3647 | |
3648 | REGISTER_KERNELS(GPU, Eigen::half); |
3649 | REGISTER_KERNELS(GPU, float); |
3650 | REGISTER_KERNELS(GPU, double); |
3651 | REGISTER_KERNELS(GPU, complex64); |
3652 | REGISTER_KERNELS(GPU, complex128); |
3653 | #endif |
3654 | #undef REGISTER_CPU_KERNELS |
3655 | #undef REGISTER_KERNELS |
3656 | |
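// AMSGrad differs from Adam only in keeping a running maximum of the second
// moment, vhat = max(vhat, v), and normalizing by sqrt(vhat) + epsilon instead
// of sqrt(v) + epsilon.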
3657 | template <typename Device, typename T> |
3658 | class ApplyAdamWithAmsgradOp : public OpKernel { |
3659 | public: |
3660 | explicit ApplyAdamWithAmsgradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
3661 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3662 | } |
3663 | |
3664 | void Compute(OpKernelContext* ctx) override { |
3665 | const bool sparse = false; |
3666 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
3667 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
3668 | |
3669 | Tensor var; |
3670 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3671 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3672 | Tensor m; |
3673 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3674 | ctx, 1, use_exclusive_lock_, sparse, &m)); |
3675 | Tensor v; |
3676 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3677 | ctx, 2, use_exclusive_lock_, sparse, &v)); |
3678 | Tensor vhat; |
3679 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3680 | ctx, 3, use_exclusive_lock_, sparse, &vhat)); |
3681 | OP_REQUIRES( |
3682 | ctx, var.IsInitialized(), |
3683 | errors::FailedPrecondition( |
3684 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3685 | OP_REQUIRES( |
3686 | ctx, m.IsInitialized(), |
3687 | errors::FailedPrecondition( |
3688 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3689 | OP_REQUIRES( |
3690 | ctx, v.IsInitialized(), |
3691 | errors::FailedPrecondition( |
3692 | "Attempting to use uninitialized variables: " , requested_input(2))); |
3693 | OP_REQUIRES( |
3694 | ctx, vhat.IsInitialized(), |
3695 | errors::FailedPrecondition( |
3696 | "Attempting to use uninitialized variables: " , requested_input(2))); |
3697 | |
3698 | const Tensor& beta1_power = ctx->input(4); |
3699 | const Tensor& beta2_power = ctx->input(5); |
3700 | const Tensor& lr = ctx->input(6); |
3701 | const Tensor& beta1 = ctx->input(7); |
3702 | const Tensor& beta2 = ctx->input(8); |
3703 | const Tensor& epsilon = ctx->input(9); |
3704 | |
3705 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), |
3706 | errors::InvalidArgument("beta1_power is not a scalar: " , |
3707 | beta1_power.shape().DebugString())); |
3708 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()), |
3709 | errors::InvalidArgument("beta2_power is not a scalar: " , |
3710 | beta2_power.shape().DebugString())); |
3711 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3712 | errors::InvalidArgument("lr is not a scalar : " , |
3713 | lr.shape().DebugString())); |
3714 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), |
3715 | errors::InvalidArgument("beta1 is not a scalar: " , |
3716 | beta1.shape().DebugString())); |
3717 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), |
3718 | errors::InvalidArgument("beta2 is not a scalar: " , |
3719 | beta2.shape().DebugString())); |
3720 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
3721 | errors::InvalidArgument("epsilon is not a scalar: " , |
3722 | epsilon.shape().DebugString())); |
3723 | |
3724 | const Tensor& grad = ctx->input(10); |
3725 | OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), |
3726 | errors::InvalidArgument("var and m do not have the same shape" , |
3727 | var.shape().DebugString(), " " , |
3728 | m.shape().DebugString())); |
3729 | OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), |
3730 | errors::InvalidArgument("var and v do not have the same shape" , |
3731 | var.shape().DebugString(), " " , |
3732 | v.shape().DebugString())); |
3733 | OP_REQUIRES( |
3734 | ctx, var.shape().IsSameSize(grad.shape()), |
3735 | errors::InvalidArgument("var and grad do not have the same shape" , |
3736 | var.shape().DebugString(), " " , |
3737 | grad.shape().DebugString())); |
3738 | |
3739 | const Device& device = ctx->template eigen_device<Device>(); |
3740 | functor::ApplyAdamWithAmsgrad<Device, T>()( |
3741 | device, var.flat<T>(), m.flat<T>(), v.flat<T>(), vhat.flat<T>(), |
3742 | beta1_power.scalar<T>(), beta2_power.scalar<T>(), lr.scalar<T>(), |
3743 | beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(), |
3744 | grad.flat<T>()); |
3745 | |
3746 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
3747 | } |
3748 | |
3749 | private: |
3750 | bool use_exclusive_lock_; |
3751 | }; |
3752 | |
3753 | #define REGISTER_KERNELS(D, T) \ |
3754 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdamWithAmsgrad") \ |
3755 | .HostMemory("var") \ |
3756 | .HostMemory("m") \ |
3757 | .HostMemory("v") \ |
3758 | .HostMemory("vhat") \ |
3759 | .Device(DEVICE_##D) \ |
3760 | .TypeConstraint<T>("T"), \ |
3761 | ApplyAdamWithAmsgradOp<D##Device, T>); |
3762 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
3763 | |
3764 | TF_CALL_half(REGISTER_CPU_KERNELS); |
3765 | TF_CALL_bfloat16(REGISTER_CPU_KERNELS); |
3766 | TF_CALL_float(REGISTER_CPU_KERNELS); |
3767 | TF_CALL_double(REGISTER_CPU_KERNELS); |
3768 | |
3769 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3770 | // Forward declarations of the functor specializations for GPU. |
3771 | namespace functor { |
3772 | #define DECLARE_GPU_SPEC(T) \ |
3773 | template <> \ |
3774 | void ApplyAdamWithAmsgrad<GPUDevice, T>::operator()( \ |
3775 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
3776 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \ |
3777 | typename TTypes<T>::Flat vhat, \ |
3778 | typename TTypes<T>::ConstScalar beta1_power, \ |
3779 | typename TTypes<T>::ConstScalar beta2_power, \ |
3780 | typename TTypes<T>::ConstScalar lr, \ |
3781 | typename TTypes<T>::ConstScalar beta1, \ |
3782 | typename TTypes<T>::ConstScalar beta2, \ |
3783 | typename TTypes<T>::ConstScalar epsilon, \ |
3784 | typename TTypes<T>::ConstFlat grad); \ |
3785 | extern template struct ApplyAdamWithAmsgrad<GPUDevice, T>; |
3786 | DECLARE_GPU_SPEC(Eigen::half); |
3787 | DECLARE_GPU_SPEC(float); |
3788 | DECLARE_GPU_SPEC(double); |
3789 | #undef DECLARE_GPU_SPEC |
3790 | } // namespace functor |
3791 | |
3792 | REGISTER_KERNELS(GPU, Eigen::half); |
3793 | REGISTER_KERNELS(GPU, float); |
3794 | REGISTER_KERNELS(GPU, double); |
3795 | #endif |
3796 | #undef REGISTER_CPU_KERNELS |
3797 | #undef REGISTER_KERNELS |
3798 | |
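// Sketch of the AdaMax update (the infinity-norm variant of Adam):
//   m   += (grad - m) * (1 - beta1)
//   v    = max(beta2 * v, |grad|)
//   var -= lr / (1 - beta1_power) * m / (v + epsilon)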
3799 | template <typename Device, typename T> |
3800 | class ApplyAdaMaxOp : public OpKernel { |
3801 | public: |
3802 | explicit ApplyAdaMaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
3803 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3804 | } |
3805 | |
3806 | void Compute(OpKernelContext* ctx) override { |
3807 | const bool sparse = false; |
3808 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
3809 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
3810 | |
3811 | Tensor var; |
3812 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3813 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3814 | Tensor m; |
3815 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3816 | ctx, 1, use_exclusive_lock_, sparse, &m)); |
3817 | Tensor v; |
3818 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3819 | ctx, 2, use_exclusive_lock_, sparse, &v)); |
3820 | OP_REQUIRES( |
3821 | ctx, var.IsInitialized(), |
3822 | errors::FailedPrecondition( |
3823 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3824 | OP_REQUIRES( |
3825 | ctx, m.IsInitialized(), |
3826 | errors::FailedPrecondition( |
3827 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3828 | OP_REQUIRES( |
3829 | ctx, v.IsInitialized(), |
3830 | errors::FailedPrecondition( |
3831 | "Attempting to use uninitialized variables: " , requested_input(2))); |
3832 | |
3833 | const Tensor& beta1_power = ctx->input(3); |
3834 | const Tensor& lr = ctx->input(4); |
3835 | const Tensor& beta1 = ctx->input(5); |
3836 | const Tensor& beta2 = ctx->input(6); |
3837 | const Tensor& epsilon = ctx->input(7); |
3838 | |
3839 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), |
3840 | errors::InvalidArgument("beta1_power is not a scalar: " , |
3841 | beta1_power.shape().DebugString())); |
3842 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3843 | errors::InvalidArgument("lr is not a scalar : " , |
3844 | lr.shape().DebugString())); |
3845 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), |
3846 | errors::InvalidArgument("beta1 is not a scalar: " , |
3847 | beta1.shape().DebugString())); |
3848 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), |
3849 | errors::InvalidArgument("beta2 is not a scalar: " , |
3850 | beta2.shape().DebugString())); |
3851 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
3852 | errors::InvalidArgument("epsilon is not a scalar: " , |
3853 | epsilon.shape().DebugString())); |
3854 | |
3855 | const Tensor& grad = ctx->input(8); |
3856 | OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), |
3857 | errors::InvalidArgument("var and m do not have the same shape" , |
3858 | var.shape().DebugString(), " " , |
3859 | m.shape().DebugString())); |
3860 | OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), |
3861 | errors::InvalidArgument("var and v do not have the same shape" , |
3862 | var.shape().DebugString(), " " , |
3863 | v.shape().DebugString())); |
3864 | OP_REQUIRES( |
3865 | ctx, var.shape().IsSameSize(grad.shape()), |
3866 | errors::InvalidArgument("var and grad do not have the same shape" , |
3867 | var.shape().DebugString(), " " , |
3868 | grad.shape().DebugString())); |
3869 | |
3870 | const Device& device = ctx->template eigen_device<Device>(); |
3871 | functor::ApplyAdaMax<Device, T>()( |
3872 | device, var.flat<T>(), m.flat<T>(), v.flat<T>(), |
3873 | beta1_power.scalar<T>(), lr.scalar<T>(), beta1.scalar<T>(), |
3874 | beta2.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>()); |
3875 | |
3876 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
3877 | } |
3878 | |
3879 | private: |
3880 | bool use_exclusive_lock_; |
3881 | }; |
3882 | |
3883 | #define REGISTER_KERNELS(D, T) \ |
3884 | REGISTER_KERNEL_BUILDER( \ |
3885 | Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
3886 | ApplyAdaMaxOp<D##Device, T>); \ |
3887 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \ |
3888 | .HostMemory("var") \ |
3889 | .HostMemory("m") \ |
3890 | .HostMemory("v") \ |
3891 | .Device(DEVICE_##D) \ |
3892 | .TypeConstraint<T>("T"), \ |
3893 | ApplyAdaMaxOp<D##Device, T>); |
3894 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
3895 | |
3896 | TF_CALL_half(REGISTER_CPU_KERNELS); |
3897 | TF_CALL_float(REGISTER_CPU_KERNELS); |
3898 | TF_CALL_double(REGISTER_CPU_KERNELS); |
3899 | |
3900 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
3901 | // Forward declarations of the functor specializations for GPU. |
3902 | namespace functor { |
3903 | #define DECLARE_GPU_SPEC(T) \ |
3904 | template <> \ |
3905 | void ApplyAdaMax<GPUDevice, T>::operator()( \ |
3906 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
3907 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \ |
3908 | typename TTypes<T>::ConstScalar beta1_power, \ |
3909 | typename TTypes<T>::ConstScalar lr, \ |
3910 | typename TTypes<T>::ConstScalar beta1, \ |
3911 | typename TTypes<T>::ConstScalar beta2, \ |
3912 | typename TTypes<T>::ConstScalar epsilon, \ |
3913 | typename TTypes<T>::ConstFlat grad); \ |
3914 | extern template struct ApplyAdaMax<GPUDevice, T>; |
3915 | DECLARE_GPU_SPEC(Eigen::half); |
3916 | DECLARE_GPU_SPEC(float); |
3917 | DECLARE_GPU_SPEC(double); |
3918 | #undef DECLARE_GPU_SPEC |
3919 | } // namespace functor |
3920 | |
3921 | REGISTER_KERNELS(GPU, Eigen::half); |
3922 | REGISTER_KERNELS(GPU, float); |
3923 | REGISTER_KERNELS(GPU, double); |
3924 | #endif |
3925 | #undef REGISTER_CPU_KERNELS |
3926 | #undef REGISTER_KERNELS |
3927 | |
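// Sketch of the RMSProp update applied by the functor:
//   ms  += (grad * grad - ms) * (1 - rho)
//   mom  = mom * momentum + lr * grad / sqrt(ms + epsilon)
//   var -= mom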
3928 | template <typename Device, typename T> |
3929 | class ApplyRMSPropOp : public OpKernel { |
3930 | public: |
3931 | explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
3932 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
3933 | } |
3934 | |
3935 | void Compute(OpKernelContext* ctx) override { |
3936 | const bool sparse = false; |
3937 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
3938 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
3939 | |
3940 | Tensor var; |
3941 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3942 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
3943 | Tensor ms; |
3944 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3945 | ctx, 1, use_exclusive_lock_, sparse, &ms)); |
3946 | Tensor mom; |
3947 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
3948 | ctx, 2, use_exclusive_lock_, sparse, &mom)); |
3949 | |
3950 | OP_REQUIRES( |
3951 | ctx, var.IsInitialized(), |
3952 | errors::FailedPrecondition( |
3953 | "Attempting to use uninitialized variables: " , requested_input(0))); |
3954 | OP_REQUIRES( |
3955 | ctx, ms.IsInitialized(), |
3956 | errors::FailedPrecondition( |
3957 | "Attempting to use uninitialized variables: " , requested_input(1))); |
3958 | OP_REQUIRES( |
3959 | ctx, mom.IsInitialized(), |
3960 | errors::FailedPrecondition( |
3961 | "Attempting to use uninitialized variables: " , requested_input(2))); |
3962 | |
3963 | const Tensor& lr = ctx->input(3); |
3964 | const Tensor& rho = ctx->input(4); |
3965 | const Tensor& momentum = ctx->input(5); |
3966 | const Tensor& epsilon = ctx->input(6); |
3967 | const Tensor& grad = ctx->input(7); |
3968 | |
3969 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
3970 | errors::InvalidArgument("lr is not a scalar : " , |
3971 | lr.shape().DebugString())); |
3972 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), |
3973 | errors::InvalidArgument("rho is not a scalar: " , |
3974 | rho.shape().DebugString())); |
3975 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
3976 | errors::InvalidArgument("momentum is not a scalar: " , |
3977 | momentum.shape().DebugString())); |
3978 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
3979 | errors::InvalidArgument("epsilon is not a scalar: " , |
3980 | epsilon.shape().DebugString())); |
3981 | |
3982 | OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), |
3983 | errors::InvalidArgument("var and ms do not have the same shape" , |
3984 | var.shape().DebugString(), " " , |
3985 | ms.shape().DebugString())); |
3986 | |
3987 | OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), |
3988 | errors::InvalidArgument( |
3989 | "var and mom do not have the same shape" , |
3990 | var.shape().DebugString(), " " , mom.shape().DebugString())); |
3991 | |
3992 | OP_REQUIRES( |
3993 | ctx, var.shape().IsSameSize(grad.shape()), |
3994 | errors::InvalidArgument("var and grad do not have the same shape" , |
3995 | var.shape().DebugString(), " " , |
3996 | grad.shape().DebugString())); |
3997 | |
3998 | const Device& device = ctx->template eigen_device<Device>(); |
3999 | functor::ApplyRMSProp<Device, T>()(device, var.flat<T>(), ms.flat<T>(), |
4000 | mom.flat<T>(), lr.scalar<T>(), |
4001 | rho.scalar<T>(), momentum.scalar<T>(), |
4002 | epsilon.scalar<T>(), grad.flat<T>()); |
4003 | |
4004 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
4005 | } |
4006 | |
4007 | private: |
4008 | bool use_exclusive_lock_; |
4009 | }; |
4010 | |
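// Dense centered RMSProp kernel. In addition to the second-moment slot ms it
// maintains a first-moment slot mg, and divides by an estimate of the
// gradient variance, ms - mg^2, instead of the raw second moment. A hedged
// sketch of the update, mirroring the sparse centered kernel later in this
// file:
//
//   mg  <- rho * mg + (1 - rho) * grad
//   ms  <- rho * ms + (1 - rho) * grad^2
//   mom <- momentum * mom + lr * grad / sqrt(ms + epsilon - mg^2)
//   var <- var - mom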
4011 | template <typename Device, typename T> |
4012 | class ApplyCenteredRMSPropOp : public OpKernel { |
4013 | public: |
4014 | explicit ApplyCenteredRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
4015 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
4016 | } |
4017 | |
4018 | void Compute(OpKernelContext* ctx) override { |
4019 | const bool sparse = false; |
4020 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
4021 | ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3}); |
4022 | |
4023 | Tensor var; |
4024 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4025 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
4026 | Tensor mg; |
4027 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4028 | ctx, 1, use_exclusive_lock_, sparse, &mg)); |
4029 | Tensor ms; |
4030 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4031 | ctx, 2, use_exclusive_lock_, sparse, &ms)); |
4032 | Tensor mom; |
4033 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4034 | ctx, 3, use_exclusive_lock_, sparse, &mom)); |
4035 | |
4036 | OP_REQUIRES( |
4037 | ctx, var.IsInitialized(), |
4038 | errors::FailedPrecondition( |
4039 | "Attempting to use uninitialized variables: " , requested_input(0))); |
4040 | OP_REQUIRES( |
4041 | ctx, mg.IsInitialized(), |
4042 | errors::FailedPrecondition( |
4043 | "Attempting to use uninitialized variables: " , requested_input(1))); |
4044 | OP_REQUIRES( |
4045 | ctx, ms.IsInitialized(), |
4046 | errors::FailedPrecondition( |
4047 | "Attempting to use uninitialized variables: " , requested_input(2))); |
4048 | OP_REQUIRES( |
4049 | ctx, mom.IsInitialized(), |
4050 | errors::FailedPrecondition( |
4051 | "Attempting to use uninitialized variables: " , requested_input(3))); |
4052 | |
4053 | const Tensor& lr = ctx->input(4); |
4054 | const Tensor& rho = ctx->input(5); |
4055 | const Tensor& momentum = ctx->input(6); |
4056 | const Tensor& epsilon = ctx->input(7); |
4057 | const Tensor& grad = ctx->input(8); |
4058 | |
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
4062 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), |
4063 | errors::InvalidArgument("rho is not a scalar: " , |
4064 | rho.shape().DebugString())); |
4065 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
4066 | errors::InvalidArgument("momentum is not a scalar: " , |
4067 | momentum.shape().DebugString())); |
4068 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
4069 | errors::InvalidArgument("epsilon is not a scalar: " , |
4070 | epsilon.shape().DebugString())); |
4071 | |
    OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()),
                errors::InvalidArgument("var and mg do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        mg.shape().DebugString()));
4076 | |
4077 | OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), |
4078 | errors::InvalidArgument("var and ms do not have the same shape" , |
4079 | var.shape().DebugString(), " " , |
4080 | ms.shape().DebugString())); |
4081 | |
4082 | OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), |
4083 | errors::InvalidArgument( |
4084 | "var and mom do not have the same shape" , |
4085 | var.shape().DebugString(), " " , mom.shape().DebugString())); |
4086 | |
4087 | OP_REQUIRES( |
4088 | ctx, var.shape().IsSameSize(grad.shape()), |
4089 | errors::InvalidArgument("var and grad do not have the same shape" , |
4090 | var.shape().DebugString(), " " , |
4091 | grad.shape().DebugString())); |
4092 | |
4093 | const Device& device = ctx->template eigen_device<Device>(); |
4094 | functor::ApplyCenteredRMSProp<Device, T>()( |
4095 | device, var.flat<T>(), mg.flat<T>(), ms.flat<T>(), mom.flat<T>(), |
4096 | lr.scalar<T>(), rho.scalar<T>(), momentum.scalar<T>(), |
4097 | epsilon.scalar<T>(), grad.flat<T>()); |
4098 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
4099 | } |
4100 | |
4101 | private: |
4102 | bool use_exclusive_lock_; |
4103 | }; |
4104 | |
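// REGISTER_KERNELS(D, T) expands DEVICE_##D to DEVICE_CPU / DEVICE_GPU and
// D##Device to CPUDevice / GPUDevice, registering both the ref-variable ops
// (ApplyRMSProp, ApplyCenteredRMSProp) and their resource-variable
// counterparts for device D and dtype T. The Resource* variants pin their
// variable handle inputs to host memory, since resource handles are scalar
// values kept on the host.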
4105 | #define REGISTER_KERNELS(D, T) \ |
4106 | REGISTER_KERNEL_BUILDER( \ |
4107 | Name("ApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
4108 | ApplyRMSPropOp<D##Device, T>); \ |
4109 | REGISTER_KERNEL_BUILDER( \ |
4110 | Name("ApplyCenteredRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
4111 | ApplyCenteredRMSPropOp<D##Device, T>); \ |
4112 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyRMSProp") \ |
4113 | .Device(DEVICE_##D) \ |
4114 | .HostMemory("var") \ |
4115 | .HostMemory("ms") \ |
4116 | .HostMemory("mom") \ |
4117 | .TypeConstraint<T>("T"), \ |
4118 | ApplyRMSPropOp<D##Device, T>); \ |
4119 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyCenteredRMSProp") \ |
4120 | .Device(DEVICE_##D) \ |
4121 | .HostMemory("var") \ |
4122 | .HostMemory("mg") \ |
4123 | .HostMemory("ms") \ |
4124 | .HostMemory("mom") \ |
4125 | .TypeConstraint<T>("T"), \ |
4126 | ApplyCenteredRMSPropOp<D##Device, T>); |
4127 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
4128 | |
4129 | TF_CALL_FLOAT_TYPES(REGISTER_CPU_KERNELS); |
4130 | TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS); |
4131 | |
4132 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
4133 | // Forward declarations of the functor specializations for GPU. |
4134 | namespace functor { |
4135 | #define DECLARE_GPU_SPEC(T) \ |
4136 | template <> \ |
4137 | void ApplyRMSProp<GPUDevice, T>::operator()( \ |
4138 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
4139 | typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, \ |
4140 | typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \ |
4141 | typename TTypes<T>::ConstScalar momentum, \ |
4142 | typename TTypes<T>::ConstScalar epsilon, \ |
4143 | typename TTypes<T>::ConstFlat grad); \ |
4144 | extern template struct ApplyRMSProp<GPUDevice, T>; \ |
4145 | template <> \ |
4146 | void ApplyCenteredRMSProp<GPUDevice, T>::operator()( \ |
4147 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
4148 | typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms, \ |
4149 | typename TTypes<T>::Flat mom, typename TTypes<T>::ConstScalar lr, \ |
4150 | typename TTypes<T>::ConstScalar rho, \ |
4151 | typename TTypes<T>::ConstScalar momentum, \ |
4152 | typename TTypes<T>::ConstScalar epsilon, \ |
4153 | typename TTypes<T>::ConstFlat grad); \ |
4154 | extern template struct ApplyCenteredRMSProp<GPUDevice, T>; |
4155 | DECLARE_GPU_SPEC(Eigen::half); |
4156 | DECLARE_GPU_SPEC(float); |
4157 | DECLARE_GPU_SPEC(double); |
4158 | DECLARE_GPU_SPEC(complex64); |
4159 | DECLARE_GPU_SPEC(complex128); |
4160 | #undef DECLARE_GPU_SPEC |
4161 | } // namespace functor |
4162 | |
4163 | REGISTER_KERNELS(GPU, Eigen::half); |
4164 | REGISTER_KERNELS(GPU, float); |
4165 | REGISTER_KERNELS(GPU, double); |
4166 | REGISTER_KERNELS(GPU, complex64); |
4167 | REGISTER_KERNELS(GPU, complex128); |
4168 | #endif |
4169 | #undef REGISTER_CPU_KERNELS |
4170 | #undef REGISTER_KERNELS |
4171 | |
4172 | // Note, this op works on cpu only. |
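// It reshapes var/ms/mom/grad to [dim0, rest] via flat_outer_dims() and then
// updates one row at a time: row indices_vec(i) of var and of the
// accumulators is combined with row i of grad using the same RMSProp rule as
// the dense kernel above.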
4173 | template <typename T, typename Tindex> |
4174 | class SparseApplyRMSPropOp : public OpKernel { |
4175 | public: |
4176 | explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
4177 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
4178 | } |
4179 | |
4180 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
4181 | const bool sparse = true; |
4182 | auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>( |
4183 | ctx, use_exclusive_lock_, sparse, {0, 1, 2}); |
4184 | |
4185 | Tensor var; |
4186 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
4187 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
4188 | Tensor ms; |
4189 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
4190 | ctx, 1, use_exclusive_lock_, sparse, &ms)); |
4191 | Tensor mom; |
4192 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
4193 | ctx, 2, use_exclusive_lock_, sparse, &mom)); |
4194 | |
4195 | OP_REQUIRES( |
4196 | ctx, var.IsInitialized(), |
4197 | errors::FailedPrecondition( |
4198 | "Attempting to use uninitialized variables: " , requested_input(0))); |
4199 | OP_REQUIRES( |
4200 | ctx, ms.IsInitialized(), |
4201 | errors::FailedPrecondition( |
4202 | "Attempting to use uninitialized variables: " , requested_input(1))); |
4203 | OP_REQUIRES( |
4204 | ctx, mom.IsInitialized(), |
4205 | errors::FailedPrecondition( |
4206 | "Attempting to use uninitialized variables: " , requested_input(2))); |
4207 | |
4208 | const Tensor& lr = ctx->input(3); |
4209 | const Tensor& rho = ctx->input(4); |
4210 | const Tensor& momentum = ctx->input(5); |
4211 | const Tensor& epsilon = ctx->input(6); |
4212 | const Tensor& grad = ctx->input(7); |
4213 | const Tensor& indices = ctx->input(8); |
4214 | |
4215 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
4216 | errors::InvalidArgument("lr is not a scalar: " , |
4217 | lr.shape().DebugString())); |
4218 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), |
4219 | errors::InvalidArgument("rho is not a scalar: " , |
4220 | rho.shape().DebugString())); |
4221 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
4222 | errors::InvalidArgument("momentum is not a scalar: " , |
4223 | momentum.shape().DebugString())); |
4224 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
4225 | errors::InvalidArgument("epsilon is not a scalar: " , |
4226 | epsilon.shape().DebugString())); |
4227 | |
4228 | OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), |
4229 | errors::InvalidArgument("var and ms do not have the same shape" , |
4230 | var.shape().DebugString(), " " , |
4231 | ms.shape().DebugString())); |
4232 | |
4233 | OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), |
4234 | errors::InvalidArgument( |
4235 | "var and mom do not have the same shape" , |
4236 | var.shape().DebugString(), " " , mom.shape().DebugString())); |
4237 | |
4238 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
4239 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
4240 | |
4241 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
4242 | errors::InvalidArgument("indices must be one-dimensional" )); |
4243 | |
4244 | for (int d = 1; d < var.dims(); d++) { |
4245 | OP_REQUIRES( |
4246 | ctx, var.dim_size(d) == grad.dim_size(d), |
4247 | errors::InvalidArgument("var and grad must match in dimension " , d)); |
4248 | } |
4249 | const Tindex N = indices.dim_size(0); |
4250 | OP_REQUIRES( |
4251 | ctx, grad.dim_size(0) == N, |
4252 | errors::InvalidArgument( |
4253 | "grad must be the same size as indices in the first dimension." )); |
4254 | |
4255 | if (N > 0) { |
4256 | const Tindex first_dim_size = var.dim_size(0); |
4257 | // Validate all the indices are in range |
4258 | auto indices_vec = indices.vec<Tindex>(); |
4259 | for (Tindex i = 0; i < N; i++) { |
4260 | const Tindex index = indices_vec(i); |
4261 | OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, |
4262 | errors::InvalidArgument( |
4263 | strings::StrCat("Index " , index, " at offset " , i, |
4264 | " in indices is out of range" ))); |
4265 | } |
4266 | |
4267 | auto var_flat = var.flat_outer_dims<T>(); |
4268 | auto ms_flat = ms.flat_outer_dims<T>(); |
4269 | auto mom_flat = mom.flat_outer_dims<T>(); |
4270 | auto grad_flat = grad.flat_outer_dims<T>(); |
4271 | const T lr_scalar = lr.scalar<T>()(); |
4272 | const T rho_scalar = rho.scalar<T>()(); |
4273 | const T epsilon_scalar = epsilon.scalar<T>()(); |
4274 | const T momentum_scalar = momentum.scalar<T>()(); |
4275 | |
4276 | for (Tindex i = 0; i < N; i++) { |
4277 | const Tindex index = indices_vec(i); |
4278 | |
4279 | auto ms_ = ms_flat.template chip<0>(index); |
4280 | auto mom_ = mom_flat.template chip<0>(index); |
4281 | auto grad_ = grad_flat.template chip<0>(i); |
4282 | |
4283 | ms_ = ms_ * ms_.constant(rho_scalar) + |
4284 | grad_.square() * grad_.constant(T(1) - rho_scalar); |
4285 | mom_ = mom_ * mom_.constant(momentum_scalar) + |
4286 | (ms_ + ms_.constant(epsilon_scalar)).rsqrt() * |
4287 | ms_.constant(lr_scalar) * grad_; |
4288 | |
4289 | auto v = var_flat.template chip<0>(index); |
4290 | v -= mom_; |
4291 | } |
4292 | } |
4293 | |
4294 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
4295 | } |
4296 | |
4297 | private: |
4298 | bool use_exclusive_lock_; |
4299 | }; |
4300 | |
4301 | // Note, this op works on cpu only. |
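// Same row-wise structure as SparseApplyRMSPropOp above, with the additional
// first-moment slot mg; the denominator for the selected row becomes
// ms + epsilon - mg^2, as in the dense centered kernel.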
4302 | template <typename T, typename Tindex> |
4303 | class SparseApplyCenteredRMSPropOp : public OpKernel { |
4304 | public: |
4305 | explicit SparseApplyCenteredRMSPropOp(OpKernelConstruction* ctx) |
4306 | : OpKernel(ctx) { |
4307 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
4308 | } |
4309 | |
4310 | void Compute(OpKernelContext* ctx) override TF_NO_THREAD_SAFETY_ANALYSIS { |
4311 | const bool sparse = true; |
4312 | auto locks = MaybeLockVariableInputMutexesInOrder<CPUDevice, T>( |
4313 | ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3}); |
4314 | |
4315 | Tensor var; |
4316 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
4317 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
4318 | Tensor mg; |
4319 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
4320 | ctx, 1, use_exclusive_lock_, sparse, &mg)); |
4321 | Tensor ms; |
4322 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
4323 | ctx, 2, use_exclusive_lock_, sparse, &ms)); |
4324 | Tensor mom; |
4325 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( |
4326 | ctx, 3, use_exclusive_lock_, sparse, &mom)); |
4327 | |
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, mg.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, ms.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(2)));
    OP_REQUIRES(
        ctx, mom.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(3)));
4340 | |
4341 | const Tensor& lr = ctx->input(4); |
4342 | const Tensor& rho = ctx->input(5); |
4343 | const Tensor& momentum = ctx->input(6); |
4344 | const Tensor& epsilon = ctx->input(7); |
4345 | const Tensor& grad = ctx->input(8); |
4346 | const Tensor& indices = ctx->input(9); |
4347 | |
4348 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
4349 | errors::InvalidArgument("lr is not a scalar: " , |
4350 | lr.shape().DebugString())); |
4351 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), |
4352 | errors::InvalidArgument("rho is not a scalar: " , |
4353 | rho.shape().DebugString())); |
4354 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), |
4355 | errors::InvalidArgument("momentum is not a scalar: " , |
4356 | momentum.shape().DebugString())); |
4357 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), |
4358 | errors::InvalidArgument("epsilon is not a scalar: " , |
4359 | epsilon.shape().DebugString())); |
4360 | |
4361 | OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()), |
4362 | errors::InvalidArgument("var and mg do not have the same shape" , |
4363 | var.shape().DebugString(), " " , |
4364 | mg.shape().DebugString())); |
4365 | |
4366 | OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), |
4367 | errors::InvalidArgument("var and ms do not have the same shape" , |
4368 | var.shape().DebugString(), " " , |
4369 | ms.shape().DebugString())); |
4370 | |
4371 | OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), |
4372 | errors::InvalidArgument( |
4373 | "var and mom do not have the same shape" , |
4374 | var.shape().DebugString(), " " , mom.shape().DebugString())); |
4375 | |
4376 | OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), |
4377 | errors::InvalidArgument("var must be at least 1 dimensional" )); |
4378 | |
4379 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), |
4380 | errors::InvalidArgument("indices must be one-dimensional" )); |
4381 | |
4382 | for (int d = 1; d < var.dims(); d++) { |
4383 | OP_REQUIRES( |
4384 | ctx, var.dim_size(d) == grad.dim_size(d), |
4385 | errors::InvalidArgument("var and grad must match in dimension " , d)); |
4386 | } |
4387 | const Tindex N = indices.dim_size(0); |
4388 | OP_REQUIRES( |
4389 | ctx, grad.dim_size(0) == N, |
4390 | errors::InvalidArgument( |
4391 | "grad must be the same size as indices in the first dimension." )); |
4392 | |
4393 | if (N > 0) { |
4394 | const Tindex first_dim_size = var.dim_size(0); |
4395 | // Validate all the indices are in range |
4396 | auto indices_vec = indices.vec<Tindex>(); |
4397 | for (Tindex i = 0; i < N; i++) { |
4398 | const Tindex index = indices_vec(i); |
4399 | OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, |
4400 | errors::InvalidArgument( |
4401 | strings::StrCat("Index " , index, " at offset " , i, |
4402 | " in indices is out of range" ))); |
4403 | } |
4404 | |
4405 | auto var_flat = var.flat_outer_dims<T>(); |
4406 | auto ms_flat = ms.flat_outer_dims<T>(); |
4407 | auto mg_flat = mg.flat_outer_dims<T>(); |
4408 | auto mom_flat = mom.flat_outer_dims<T>(); |
4409 | auto grad_flat = grad.flat_outer_dims<T>(); |
4410 | const T lr_scalar = lr.scalar<T>()(); |
4411 | const T rho_scalar = rho.scalar<T>()(); |
4412 | const T epsilon_scalar = epsilon.scalar<T>()(); |
4413 | const T momentum_scalar = momentum.scalar<T>()(); |
4414 | |
4415 | for (Tindex i = 0; i < N; i++) { |
4416 | const Tindex index = indices_vec(i); |
4417 | |
4418 | auto ms_ = ms_flat.template chip<0>(index); |
4419 | auto mom_ = mom_flat.template chip<0>(index); |
4420 | auto grad_ = grad_flat.template chip<0>(i); |
4421 | |
4422 | ms_ = ms_ * ms_.constant(rho_scalar) + |
4423 | grad_.square() * grad_.constant(T(1) - rho_scalar); |
4424 | |
4425 | auto mg_ = mg_flat.template chip<0>(index); |
4426 | mg_ = mg_ * mg_.constant(rho_scalar) + |
4427 | grad_ * grad_.constant(T(1) - rho_scalar); |
4428 | auto denom_ = ms_ + ms_.constant(epsilon_scalar) - mg_.square(); |
4429 | mom_ = mom_ * mom_.constant(momentum_scalar) + |
4430 | denom_.rsqrt() * ms_.constant(lr_scalar) * grad_; |
4431 | auto v = var_flat.template chip<0>(index); |
4432 | v -= mom_; |
4433 | } |
4434 | } |
4435 | |
4436 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
4437 | } |
4438 | |
4439 | private: |
4440 | bool use_exclusive_lock_; |
4441 | }; |
4442 | |
4443 | #define REGISTER_KERNELS(T, Tindices) \ |
4444 | REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \ |
4445 | .Device(DEVICE_CPU) \ |
4446 | .TypeConstraint<T>("T") \ |
4447 | .TypeConstraint<Tindices>("Tindices"), \ |
4448 | SparseApplyRMSPropOp<T, Tindices>); \ |
4449 | REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp") \ |
4450 | .Device(DEVICE_CPU) \ |
4451 | .TypeConstraint<T>("T") \ |
4452 | .TypeConstraint<Tindices>("Tindices"), \ |
4453 | SparseApplyCenteredRMSPropOp<T, Tindices>); \ |
4454 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyRMSProp") \ |
4455 | .Device(DEVICE_CPU) \ |
4456 | .TypeConstraint<T>("T") \ |
4457 | .TypeConstraint<Tindices>("Tindices"), \ |
4458 | SparseApplyRMSPropOp<T, Tindices>); \ |
4459 | REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyCenteredRMSProp") \ |
4460 | .Device(DEVICE_CPU) \ |
4461 | .TypeConstraint<T>("T") \ |
4462 | .TypeConstraint<Tindices>("Tindices"), \ |
4463 | SparseApplyCenteredRMSPropOp<T, Tindices>); |
4464 | |
4465 | REGISTER_KERNELS(Eigen::half, int32); |
4466 | REGISTER_KERNELS(Eigen::half, int64_t); |
4467 | REGISTER_KERNELS(float, int32); |
4468 | REGISTER_KERNELS(float, int64_t); |
4469 | REGISTER_KERNELS(double, int32); |
4470 | REGISTER_KERNELS(double, int64_t); |
4471 | REGISTER_KERNELS(complex64, int32); |
4472 | REGISTER_KERNELS(complex64, int64_t); |
4473 | REGISTER_KERNELS(complex128, int32); |
4474 | REGISTER_KERNELS(complex128, int64_t); |
4475 | |
4476 | #undef REGISTER_KERNELS |
4477 | |
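// AddSign kernel (from "Neural Optimizer Search with Reinforcement Learning",
// Bello et al.). The update is delegated to functor::ApplyAddSign; a hedged
// sketch of what that functor is expected to compute:
//
//   m   <- beta * m + (1 - beta) * grad
//   var <- var - lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad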
4478 | template <typename Device, typename T> |
4479 | class ApplyAddSignOp : public OpKernel { |
4480 | public: |
4481 | explicit ApplyAddSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
4482 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
4483 | } |
4484 | |
4485 | void Compute(OpKernelContext* ctx) override { |
4486 | const bool sparse = false; |
4487 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
4488 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
4489 | |
4490 | Tensor var; |
4491 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4492 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
4493 | Tensor m; |
4494 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4495 | ctx, 1, use_exclusive_lock_, sparse, &m)); |
4496 | OP_REQUIRES( |
4497 | ctx, var.IsInitialized(), |
4498 | errors::FailedPrecondition( |
4499 | "Attempting to use uninitialized variables: " , requested_input(0))); |
4500 | OP_REQUIRES( |
4501 | ctx, m.IsInitialized(), |
4502 | errors::FailedPrecondition( |
4503 | "Attempting to use uninitialized variables: " , requested_input(1))); |
4504 | const Tensor& lr = ctx->input(2); |
4505 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
4506 | errors::InvalidArgument("lr is not a scalar: " , |
4507 | lr.shape().DebugString())); |
4508 | const Tensor& alpha = ctx->input(3); |
4509 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()), |
4510 | errors::InvalidArgument("alpha is not a scalar: " , |
4511 | alpha.shape().DebugString())); |
4512 | const Tensor& sign_decay = ctx->input(4); |
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay.shape()),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay.shape().DebugString()));
4516 | const Tensor& beta = ctx->input(5); |
4517 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()), |
4518 | errors::InvalidArgument("beta is not a scalar: " , |
4519 | beta.shape().DebugString())); |
4520 | const Tensor& grad = ctx->input(6); |
4521 | OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), |
4522 | errors::InvalidArgument("var and m do not have the same shape" , |
4523 | var.shape().DebugString(), " " , |
4524 | m.shape().DebugString())); |
4525 | OP_REQUIRES( |
4526 | ctx, var.shape().IsSameSize(grad.shape()), |
4527 | errors::InvalidArgument("var and grad do not have the same shape" , |
4528 | var.shape().DebugString(), " " , |
4529 | grad.shape().DebugString())); |
4530 | |
4531 | const Device& device = ctx->template eigen_device<Device>(); |
4532 | functor::ApplyAddSign<Device, T>()( |
4533 | device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(), alpha.scalar<T>(), |
4534 | sign_decay.scalar<T>(), beta.scalar<T>(), grad.flat<T>()); |
4535 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
4536 | } |
4537 | |
4538 | private: |
4539 | bool use_exclusive_lock_; |
4540 | }; |
4541 | |
4542 | #define REGISTER_KERNELS(D, T) \ |
4543 | REGISTER_KERNEL_BUILDER( \ |
4544 | Name("ApplyAddSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
4545 | ApplyAddSignOp<D##Device, T>); \ |
4546 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyAddSign") \ |
4547 | .Device(DEVICE_##D) \ |
4548 | .HostMemory("var") \ |
4549 | .HostMemory("m") \ |
4550 | .TypeConstraint<T>("T"), \ |
4551 | ApplyAddSignOp<D##Device, T>); |
4552 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
4553 | |
4554 | TF_CALL_half(REGISTER_CPU_KERNELS); |
4555 | TF_CALL_bfloat16(REGISTER_CPU_KERNELS); |
4556 | TF_CALL_float(REGISTER_CPU_KERNELS); |
4557 | TF_CALL_double(REGISTER_CPU_KERNELS); |
4558 | |
4559 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
4560 | // Forward declarations of the functor specializations for GPU. |
4561 | namespace functor { |
4562 | #define DECLARE_GPU_SPEC(T) \ |
4563 | template <> \ |
4564 | void ApplyAddSign<GPUDevice, T>::operator()( \ |
4565 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
4566 | typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr, \ |
4567 | typename TTypes<T>::ConstScalar alpha, \ |
4568 | typename TTypes<T>::ConstScalar sign_decay, \ |
4569 | typename TTypes<T>::ConstScalar beta, \ |
4570 | typename TTypes<T>::ConstFlat grad); \ |
4571 | extern template struct ApplyAddSign<GPUDevice, T>; |
4572 | DECLARE_GPU_SPEC(Eigen::half); |
4573 | DECLARE_GPU_SPEC(float); |
4574 | DECLARE_GPU_SPEC(double); |
4575 | #undef DECLARE_GPU_SPEC |
4576 | } // namespace functor |
4577 | |
4578 | REGISTER_KERNELS(GPU, Eigen::half); |
4579 | REGISTER_KERNELS(GPU, float); |
4580 | REGISTER_KERNELS(GPU, double); |
4581 | #endif |
4582 | #undef REGISTER_CPU_KERNELS |
4583 | #undef REGISTER_KERNELS |
4584 | |
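// PowerSign kernel (Bello et al.). Same structure as AddSign above, but the
// sign agreement scales the gradient multiplicatively through an exponential.
// A hedged sketch of what functor::ApplyPowerSign is expected to compute:
//
//   m   <- beta * m + (1 - beta) * grad
//   var <- var - lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad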
4585 | template <typename Device, typename T> |
4586 | class ApplyPowerSignOp : public OpKernel { |
4587 | public: |
4588 | explicit ApplyPowerSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
4589 | OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking" , &use_exclusive_lock_)); |
4590 | } |
4591 | |
4592 | void Compute(OpKernelContext* ctx) override { |
4593 | const bool sparse = false; |
4594 | auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>( |
4595 | ctx, use_exclusive_lock_, sparse, {0, 1}); |
4596 | |
4597 | Tensor var; |
4598 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4599 | ctx, 0, use_exclusive_lock_, sparse, &var)); |
4600 | Tensor m; |
4601 | OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( |
4602 | ctx, 1, use_exclusive_lock_, sparse, &m)); |
4603 | OP_REQUIRES( |
4604 | ctx, var.IsInitialized(), |
4605 | errors::FailedPrecondition( |
4606 | "Attempting to use uninitialized variables: " , requested_input(0))); |
4607 | OP_REQUIRES( |
4608 | ctx, m.IsInitialized(), |
4609 | errors::FailedPrecondition( |
4610 | "Attempting to use uninitialized variables: " , requested_input(1))); |
4611 | const Tensor& lr = ctx->input(2); |
4612 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), |
4613 | errors::InvalidArgument("lr is not a scalar: " , |
4614 | lr.shape().DebugString())); |
4615 | const Tensor& logbase = ctx->input(3); |
4616 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase.shape()), |
4617 | errors::InvalidArgument("logbase is not a scalar: " , |
4618 | logbase.shape().DebugString())); |
4619 | const Tensor& sign_decay = ctx->input(4); |
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay.shape()),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay.shape().DebugString()));
4623 | const Tensor& beta = ctx->input(5); |
4624 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()), |
4625 | errors::InvalidArgument("beta is not a scalar: " , |
4626 | beta.shape().DebugString())); |
4627 | const Tensor& grad = ctx->input(6); |
4628 | OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), |
4629 | errors::InvalidArgument("var and m do not have the same shape" , |
4630 | var.shape().DebugString(), " " , |
4631 | m.shape().DebugString())); |
4632 | OP_REQUIRES( |
4633 | ctx, var.shape().IsSameSize(grad.shape()), |
4634 | errors::InvalidArgument("var and grad do not have the same shape" , |
4635 | var.shape().DebugString(), " " , |
4636 | grad.shape().DebugString())); |
4637 | |
4638 | const Device& device = ctx->template eigen_device<Device>(); |
4639 | functor::ApplyPowerSign<Device, T>()( |
4640 | device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(), logbase.scalar<T>(), |
4641 | sign_decay.scalar<T>(), beta.scalar<T>(), grad.flat<T>()); |
4642 | MaybeForwardRefInputToRefOutput(ctx, 0, 0); |
4643 | } |
4644 | |
4645 | private: |
4646 | bool use_exclusive_lock_; |
4647 | }; |
4648 | |
4649 | #define REGISTER_KERNELS(D, T) \ |
4650 | REGISTER_KERNEL_BUILDER( \ |
4651 | Name("ApplyPowerSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \ |
4652 | ApplyPowerSignOp<D##Device, T>); \ |
4653 | REGISTER_KERNEL_BUILDER(Name("ResourceApplyPowerSign") \ |
4654 | .Device(DEVICE_##D) \ |
4655 | .HostMemory("var") \ |
4656 | .HostMemory("m") \ |
4657 | .TypeConstraint<T>("T"), \ |
4658 | ApplyPowerSignOp<D##Device, T>); |
4659 | #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); |
4660 | |
4661 | TF_CALL_half(REGISTER_CPU_KERNELS); |
4662 | TF_CALL_bfloat16(REGISTER_CPU_KERNELS); |
4663 | TF_CALL_float(REGISTER_CPU_KERNELS); |
4664 | TF_CALL_double(REGISTER_CPU_KERNELS); |
4665 | |
4666 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
4667 | // Forward declarations of the functor specializations for GPU. |
4668 | namespace functor { |
4669 | #define DECLARE_GPU_SPEC(T) \ |
4670 | template <> \ |
4671 | void ApplyPowerSign<GPUDevice, T>::operator()( \ |
4672 | const GPUDevice& d, typename TTypes<T>::Flat var, \ |
4673 | typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr, \ |
4674 | typename TTypes<T>::ConstScalar logbase, \ |
4675 | typename TTypes<T>::ConstScalar sign_decay, \ |
4676 | typename TTypes<T>::ConstScalar beta, \ |
4677 | typename TTypes<T>::ConstFlat grad); \ |
4678 | extern template struct ApplyPowerSign<GPUDevice, T>; |
4679 | DECLARE_GPU_SPEC(Eigen::half); |
4680 | DECLARE_GPU_SPEC(float); |
4681 | DECLARE_GPU_SPEC(double); |
4682 | #undef DECLARE_GPU_SPEC |
4683 | } // namespace functor |
4684 | |
4685 | REGISTER_KERNELS(GPU, Eigen::half); |
4686 | REGISTER_KERNELS(GPU, float); |
4687 | REGISTER_KERNELS(GPU, double); |
4688 | #endif |
4689 | #undef REGISTER_CPU_KERNELS |
4690 | #undef REGISTER_KERNELS |
4691 | |
4692 | } // namespace tensorflow |
4693 | |