/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

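// Returns the shape of the value a ref or resource input refers to. If shape
// metadata is recorded for the handle, that shape is used; otherwise we fall
// back to the shape of the input tensor itself.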
template <bool is_resource>
ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
  auto* handle_data = c->input_handle_shapes_and_types(input);
  if (handle_data != nullptr && !handle_data->empty() &&
      (*handle_data)[0].dtype != DT_INVALID) {
    return (*handle_data)[0].shape;
  }
  return c->input(input);
}

template <>
ShapeHandle ShapeOrHandleShape<true>(InferenceContext* c, int input) {
  auto* handle_data = c->input_handle_shapes_and_types(input);
  if (handle_data != nullptr && !handle_data->empty() &&
      (*handle_data)[0].dtype != DT_INVALID) {
    return (*handle_data)[0].shape;
  }
  // If a resource input is missing shape information, we should return
  // UnknownShape rather than the shape of the input, which is a scalar
  // resource handle.
  return c->UnknownShape();
}

// Handles the gradient and, if <is_sparse>, the indices inputs.
// <s> is an input+output parameter: on input it holds the shape inferred for
// the variable so far, and on output that shape merged with the gradient's.
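//
// For example (a sketch, assuming a rank-2 variable): with var of shape
// [n, d], a sparse update passes grad of shape [k, d] and indices of shape
// [k]. Dim 0 of grad is checked against dim 0 of indices, and the trailing
// dims of grad are merged with the trailing dims of <s>.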
template <bool is_sparse, bool is_resource>
static Status HandleGradAndIndicesInputs(InferenceContext* c, int grad_idx,
                                         ShapeHandle* s) {
  ShapeHandle grad = ShapeOrHandleShape<is_resource>(c, grad_idx);
  if (!is_sparse) {
    TF_RETURN_IF_ERROR(c->Merge(*s, grad, s));
    return OkStatus();
  }
  // Indices is a vector whose length must match dim 0 of grad.
  ShapeHandle indices;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices));
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
  // Trailing part of grad matches trailing part of *s.
  ShapeHandle grad_unknown_first;
  TF_RETURN_IF_ERROR(
      c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
  TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));

  return OkStatus();
}

template <bool is_resource>
static Status ApplyGradientDescentShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));  // alpha
  TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));  // delta
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

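// ApplyGradientDescent: var -= alpha * delta.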
REGISTER_OP("ApplyGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("delta: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyGradientDescentShapeFn<false>);

REGISTER_OP("ResourceApplyGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("delta: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyGradientDescentShapeFn<true>);

template <bool is_sparse, bool is_resource>
Status ApplyProximalGradientDescentShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));  // alpha
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // l2
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 4 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

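// ApplyProximalGradientDescent (as documented for the kernel):
//   prox_v = var - alpha * delta
//   var = sign(prox_v) * max(|prox_v| - alpha * l1, 0) / (1 + alpha * l2)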
REGISTER_OP("ApplyProximalGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("delta: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/false,
                                                    /*is_resource=*/false>);

REGISTER_OP("SparseApplyProximalGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/true,
                                                    /*is_resource=*/false>);

REGISTER_OP("ResourceApplyProximalGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("delta: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/false,
                                                    /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyProximalGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/true,
                                                    /*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyAdadeltaShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s));  // accum_update
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // rho
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 6 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

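// ApplyAdadelta (as documented for the kernel):
//   accum = rho * accum + (1 - rho) * grad^2
//   update = sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
//   accum_update = rho * accum_update + (1 - rho) * update^2
//   var -= lr * update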
REGISTER_OP("ApplyAdadelta")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("accum_update: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdadeltaShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("SparseApplyAdadelta")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("accum_update: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdadeltaShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyAdadelta")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("accum_update: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdadeltaShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyAdadelta")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("accum_update: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyAdadeltaShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyAdagradShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 3 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

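// ApplyAdagrad: accum += grad^2; var -= lr * grad / sqrt(accum).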
REGISTER_OP("ApplyAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(ApplyAdagradShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("SparseApplyAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(ApplyAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceSparseApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(ApplyAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyAdagradV2ShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 4 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

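// ApplyAdagradV2 adds epsilon to the denominator:
//   accum += grad^2; var -= lr * grad / (sqrt(accum) + epsilon).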
REGISTER_OP("ApplyAdagradV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyAdagradV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("SparseApplyAdagradV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceSparseApplyAdagradV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyProximalAdagradShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // l2
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 5 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

REGISTER_OP("ApplyProximalAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalAdagradShapeFn</*is_sparse=*/false,
                                            /*is_resource=*/false>);

REGISTER_OP("ResourceApplyProximalAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalAdagradShapeFn</*is_sparse=*/false,
                                            /*is_resource=*/true>);

REGISTER_OP("SparseApplyProximalAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyProximalAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceSparseApplyProximalAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyProximalAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

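// In ApplyAdagradDAShapeFn and the Ftrl/Momentum shape functions below, grad
// (and, for sparse variants, indices) precede the scalar hyperparameter
// inputs, so the index of the first scalar depends on is_sparse.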
template <bool is_sparse, bool is_resource>
static Status ApplyAdagradDAShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1),
                              &s));  // gradient_accumulator
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2),
                              &s));  // gradient_squared_accumulator
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 3 /* grad_idx */, &s));
  int idx = is_sparse ? 5 : 4;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // global_step
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

REGISTER_OP("ApplyAdagradDA")
    .Input("var: Ref(T)")
    .Input("gradient_accumulator: Ref(T)")
    .Input("gradient_squared_accumulator: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("SparseApplyAdagradDA")
    .Input("var: Ref(T)")
    .Input("gradient_accumulator: Ref(T)")
    .Input("gradient_squared_accumulator: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyAdagradDA")
    .Input("var: resource")
    .Input("gradient_accumulator: resource")
    .Input("gradient_squared_accumulator: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyAdagradDA")
    .Input("var: resource")
    .Input("gradient_accumulator: resource")
    .Input("gradient_squared_accumulator: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyFtrlShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s));  // linear
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 3 /* grad_idx */, &s));
  int idx = is_sparse ? 5 : 4;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr_power
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

REGISTER_OP("ApplyFtrl")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("SparseApplyFtrl")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyFtrl")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyFtrl")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

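// The FtrlV2 variants below take an extra l2_shrinkage input. As documented
// for the kernels, this applies shrinkage-type L2 by folding
// 2 * l2_shrinkage * var into the gradient used for the linear accumulator,
// while l2 remains a purely proximal (stabilization) term. Note that the V2
// ops reuse ApplyFtrlShapeFn, so the fourth scalar it rank-checks is
// l2_shrinkage and lr_power is left unchecked.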
REGISTER_OP("ApplyFtrlV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("SparseApplyFtrlV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyFtrlV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyFtrlV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyMomentumShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 3 /* grad_idx */, &s));
  int idx = is_sparse ? 5 : 4;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // momentum
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

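// ApplyMomentum: accum = momentum * accum + grad; var -= lr * accum
// (with use_nesterov: var -= lr * (grad + momentum * accum)).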
REGISTER_OP("ApplyMomentum")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("SparseApplyMomentum")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyMomentumShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

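// The Keras variants update in the Keras convention:
//   accum = momentum * accum - lr * grad; var += accum.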
REGISTER_OP("ResourceApplyKerasMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyKerasMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyMomentumShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_resource>
static Status ApplyAdamShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s));  // v
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // beta1_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // beta2_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // beta1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));  // beta2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>(
          c, 9 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

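// ApplyAdam (as documented for the kernel):
//   lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//   m = beta1 * m + (1 - beta1) * grad
//   v = beta2 * v + (1 - beta2) * grad^2
//   var -= lr_t * m / (sqrt(v) + epsilon)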
REGISTER_OP("ApplyAdam")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("v: Ref(T)")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyAdamShapeFn</*is_resource=*/false>);

REGISTER_OP("ResourceApplyAdam")
    .Input("var: resource")
    .Input("m: resource")
    .Input("v: resource")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyAdamShapeFn</*is_resource=*/true>);

template <bool is_resource>
static Status ApplyAdamWithAmsgradShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s));  // v
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 3), &s));  // vhat
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // beta1_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // beta2_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));  // beta1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));  // beta2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(9), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>(
          c, 10 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

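// The AMSGrad variant additionally tracks vhat = max(vhat, v) and divides by
// sqrt(vhat) + epsilon instead of sqrt(v) + epsilon.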
REGISTER_OP("ResourceApplyAdamWithAmsgrad")
    .Input("var: resource")
    .Input("m: resource")
    .Input("v: resource")
    .Input("vhat: resource")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyAdamWithAmsgradShapeFn</*is_resource=*/true>);

template <bool is_resource>
static Status ApplyAdaMaxShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s));  // v
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // beta1_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // beta1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // beta2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>(
          c, 8 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

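// ApplyAdaMax (as documented for the kernel):
//   m = beta1 * m + (1 - beta1) * grad
//   v = max(beta2 * v, |grad|)
//   var -= lr / (1 - beta1_power) * m / (v + epsilon)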
REGISTER_OP("ApplyAdaMax")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("v: Ref(T)")
    .Input("beta1_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyAdaMaxShapeFn</*is_resource=*/false>);

REGISTER_OP("ResourceApplyAdaMax")
    .Input("var: resource")
    .Input("m: resource")
    .Input("v: resource")
    .Input("beta1_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyAdaMaxShapeFn</*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyRMSPropShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // ms
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s));  // mom
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // rho
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // momentum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 7 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

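// ApplyRMSProp (as documented for the kernel):
//   ms = rho * ms + (1 - rho) * grad^2
//   mom = momentum * mom + lr * grad / sqrt(ms + epsilon)
//   var -= mom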
REGISTER_OP("ApplyRMSProp")
    .Input("var: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyRMSPropShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

REGISTER_OP("SparseApplyRMSProp")
    .Input("var: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyRMSProp")
    .Input("var: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyRMSPropShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyRMSProp")
    .Input("var: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_sparse, bool is_resource>
static Status ApplyCenteredRMSPropShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // mg
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s));  // ms
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 3), &s));  // mom
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // rho
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // momentum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>(
      c, 8 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

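// The centered variant also tracks mg, a running mean of the gradients, and
// normalizes by a variance estimate (as documented for the kernel):
//   mg = rho * mg + (1 - rho) * grad
//   ms = rho * ms + (1 - rho) * grad^2
//   mom = momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
//   var -= mom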
REGISTER_OP("ApplyCenteredRMSProp")
    .Input("var: Ref(T)")
    .Input("mg: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyCenteredRMSPropShapeFn</*is_sparse=*/false,
                                            /*is_resource=*/false>);

REGISTER_OP("SparseApplyCenteredRMSProp")
    .Input("var: Ref(T)")
    .Input("mg: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyCenteredRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

REGISTER_OP("ResourceApplyCenteredRMSProp")
    .Input("var: resource")
    .Input("mg: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyCenteredRMSPropShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

REGISTER_OP("ResourceSparseApplyCenteredRMSProp")
    .Input("var: resource")
    .Input("mg: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyCenteredRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

template <bool is_resource>
static Status ApplyAddSignShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // alpha
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // sign_decay
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // beta
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>(
          c, 6 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

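// ApplyAddSign (as documented for the kernel):
//   m = beta * m + (1 - beta) * grad
//   var -= lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad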
REGISTER_OP("ApplyAddSign")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("lr: T")
    .Input("alpha: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyAddSignShapeFn</*is_resource=*/false>);

REGISTER_OP("ResourceApplyAddSign")
    .Input("var: resource")
    .Input("m: resource")
    .Input("lr: T")
    .Input("alpha: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyAddSignShapeFn</*is_resource=*/true>);

template <bool is_resource>
static Status ApplyPowerSignShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // logbase
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // sign_decay
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // beta
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>(
          c, 6 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return OkStatus();
}

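// ApplyPowerSign (as documented for the kernel):
//   m = beta * m + (1 - beta) * grad
//   var -= lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad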
REGISTER_OP("ApplyPowerSign")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("lr: T")
    .Input("logbase: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyPowerSignShapeFn</*is_resource=*/false>);

REGISTER_OP("ResourceApplyPowerSign")
    .Input("var: resource")
    .Input("m: resource")
    .Input("lr: T")
    .Input("logbase: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyPowerSignShapeFn</*is_resource=*/true>);

}  // namespace tensorflow