1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/shape_inference.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::DimensionHandle; |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | template <bool is_resource> |
26 | ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) { |
27 | auto* handle_data = c->input_handle_shapes_and_types(input); |
28 | if (handle_data != nullptr && !handle_data->empty() && |
29 | (*handle_data)[0].dtype != DT_INVALID) { |
30 | return (*handle_data)[0].shape; |
31 | } |
32 | return c->input(input); |
33 | } |
34 | |
35 | template <> |
36 | ShapeHandle ShapeOrHandleShape<true>(InferenceContext* c, int input) { |
37 | auto* handle_data = c->input_handle_shapes_and_types(input); |
38 | if (handle_data != nullptr && !handle_data->empty() && |
39 | (*handle_data)[0].dtype != DT_INVALID) { |
40 | return (*handle_data)[0].shape; |
41 | } |
42 | // If a resource input is missing shape information, we should return |
43 | // UnknownShape rather than the shape of the input, which is a scalar |
44 | // resource handle. |
45 | return c->UnknownShape(); |
46 | } |
47 | |
48 | // Handle the gradient and, if <is_sparse>, indices inputs. |
49 | // <s> is an input+output parameter, containing the current known input shape to |
50 | // the gradient. |
51 | template <bool is_sparse, bool is_resource> |
52 | static Status HandleGradAndIndicesInputs(InferenceContext* c, int grad_idx, |
53 | ShapeHandle* s) { |
54 | ShapeHandle grad = ShapeOrHandleShape<is_resource>(c, grad_idx); |
55 | if (!is_sparse) { |
56 | TF_RETURN_IF_ERROR(c->Merge(*s, grad, s)); |
57 | return OkStatus(); |
58 | } |
59 | // Indices is a vector where indices.dim[0].rank == grad[0].rank. |
60 | ShapeHandle indices; |
61 | TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices)); |
62 | DimensionHandle unused; |
63 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused)); |
64 | // Trailing part of grad matches trailing part of *s. |
65 | ShapeHandle grad_unknown_first; |
66 | TF_RETURN_IF_ERROR( |
67 | c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first)); |
68 | TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s)); |
69 | |
70 | return OkStatus(); |
71 | } |
72 | |
73 | template <bool is_resource> |
74 | static Status ApplyGradientDescentShapeFn(InferenceContext* c) { |
75 | ShapeHandle unused; |
76 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
77 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha |
78 | TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // delta |
79 | if (c->num_outputs() > 0) { |
80 | c->set_output(0, s); |
81 | } |
82 | return OkStatus(); |
83 | } |
84 | |
// Dense gradient descent on a ref variable: var -= alpha * delta.
REGISTER_OP("ApplyGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("delta: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyGradientDescentShapeFn<false>);

// Dense gradient descent on a resource variable (no output; the variable
// is updated in place through the handle).
REGISTER_OP("ResourceApplyGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("delta: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyGradientDescentShapeFn<true>);
101 | |
102 | template <bool is_sparse, bool is_resource> |
103 | Status ApplyProximalGradientDescentShapeFn(InferenceContext* c) { |
104 | ShapeHandle unused; |
105 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
106 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha |
107 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // l1 |
108 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l2 |
109 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
110 | c, 4 /* grad_idx */, &s)); |
111 | if (c->num_outputs() > 0) { |
112 | c->set_output(0, s); |
113 | } |
114 | return OkStatus(); |
115 | } |
116 | |
// Proximal gradient descent with l1/l2 regularization, ref variable.
REGISTER_OP("ApplyProximalGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("delta: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/false,
                                                    /*is_resource=*/false>);

// Sparse (row-indexed) proximal gradient descent, ref variable.
REGISTER_OP("SparseApplyProximalGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/true,
                                                    /*is_resource=*/false>);

// Proximal gradient descent, resource variable.
REGISTER_OP("ResourceApplyProximalGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("delta: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/false,
                                                    /*is_resource=*/true>);

// Sparse proximal gradient descent, resource variable.
REGISTER_OP("ResourceSparseApplyProximalGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalGradientDescentShapeFn</*is_sparse=*/true,
                                                    /*is_resource=*/true>);
166 | |
167 | template <bool is_sparse, bool is_resource> |
168 | static Status ApplyAdadeltaShapeFn(InferenceContext* c) { |
169 | ShapeHandle unused; |
170 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
171 | TF_RETURN_IF_ERROR( |
172 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // accum |
173 | TF_RETURN_IF_ERROR( |
174 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s)); // accum update |
175 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr |
176 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // rho |
177 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // epsilon |
178 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
179 | c, 6 /* grad_idx */, &s)); |
180 | if (c->num_outputs() > 0) { |
181 | c->set_output(0, s); |
182 | } |
183 | return OkStatus(); |
184 | } |
185 | |
// Adadelta update, ref variable.
REGISTER_OP("ApplyAdadelta")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("accum_update: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdadeltaShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

// Sparse (row-indexed) Adadelta update, ref variable.
REGISTER_OP("SparseApplyAdadelta")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("accum_update: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdadeltaShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// Adadelta update, resource variable.
REGISTER_OP("ResourceApplyAdadelta")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("accum_update: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdadeltaShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse Adadelta update, resource variable.
REGISTER_OP("ResourceSparseApplyAdadelta")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("accum_update: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyAdadeltaShapeFn</*is_sparse=*/true, /*is_resource=*/true>);
242 | |
243 | template <bool is_sparse, bool is_resource> |
244 | static Status ApplyAdagradShapeFn(InferenceContext* c) { |
245 | ShapeHandle unused; |
246 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
247 | TF_RETURN_IF_ERROR( |
248 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // accum |
249 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr |
250 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
251 | c, 3 /* grad_idx */, &s)); |
252 | if (c->num_outputs() > 0) { |
253 | c->set_output(0, s); |
254 | } |
255 | return OkStatus(); |
256 | } |
257 | |
// Adagrad update, ref variable.
REGISTER_OP("ApplyAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

// Adagrad update, resource variable.
REGISTER_OP("ResourceApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(ApplyAdagradShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse (row-indexed) Adagrad update, ref variable.
REGISTER_OP("SparseApplyAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(ApplyAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// Sparse Adagrad update, resource variable.
REGISTER_OP("ResourceSparseApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(ApplyAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/true>);
304 | |
305 | template <bool is_sparse, bool is_resource> |
306 | static Status ApplyAdagradV2ShapeFn(InferenceContext* c) { |
307 | ShapeHandle unused; |
308 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
309 | TF_RETURN_IF_ERROR( |
310 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // accum |
311 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr |
312 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // epsilon |
313 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
314 | c, 4 /* grad_idx */, &s)); |
315 | if (c->num_outputs() > 0) { |
316 | c->set_output(0, s); |
317 | } |
318 | return OkStatus(); |
319 | } |
320 | |
// AdagradV2 (Adagrad with explicit epsilon) update, ref variable.
REGISTER_OP("ApplyAdagradV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

// AdagradV2 update, resource variable.
REGISTER_OP("ResourceApplyAdagradV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse (row-indexed) AdagradV2 update, ref variable.
REGISTER_OP("SparseApplyAdagradV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// Sparse AdagradV2 update, resource variable.
REGISTER_OP("ResourceSparseApplyAdagradV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn(
        ApplyAdagradV2ShapeFn</*is_sparse=*/true, /*is_resource=*/true>);
374 | |
375 | template <bool is_sparse, bool is_resource> |
376 | static Status ApplyProximalAdagradShapeFn(InferenceContext* c) { |
377 | ShapeHandle unused; |
378 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
379 | TF_RETURN_IF_ERROR( |
380 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // accum |
381 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr |
382 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // l1 |
383 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // l2 |
384 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
385 | c, 5 /* grad_idx */, &s)); |
386 | if (c->num_outputs() > 0) { |
387 | c->set_output(0, s); |
388 | } |
389 | return OkStatus(); |
390 | } |
391 | |
// Proximal Adagrad update with l1/l2 regularization, ref variable.
REGISTER_OP("ApplyProximalAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyProximalAdagradShapeFn</*is_sparse=*/false,
                                            /*is_resource=*/false>);
404 | |
405 | REGISTER_OP("ResourceApplyProximalAdagrad" ) |
406 | .Input("var: resource" ) |
407 | .Input("accum: resource" ) |
408 | .Input("lr: T" ) |
409 | .Input("l1: T" ) |
410 | .Input("l2: T" ) |
411 | .Input("grad: T" ) |
412 | .Attr("T: numbertype" ) |
413 | .Attr("use_locking: bool = false" ) |
414 | .SetShapeFn(ApplyProximalAdagradShapeFn</*is_sparse=*/false, |
415 | /*is_resource=*/false>); |
416 | |
// Sparse (row-indexed) proximal Adagrad update, ref variable.
REGISTER_OP("SparseApplyProximalAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyProximalAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// Sparse proximal Adagrad update, resource variable.
REGISTER_OP("ResourceSparseApplyProximalAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyProximalAdagradShapeFn</*is_sparse=*/true, /*is_resource=*/true>);
445 | |
446 | template <bool is_sparse, bool is_resource> |
447 | static Status ApplyAdagradDAShapeFn(InferenceContext* c) { |
448 | ShapeHandle unused; |
449 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
450 | TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), |
451 | &s)); // grad_accumulator |
452 | TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), |
453 | &s)); // gradient_squared_accumulator |
454 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
455 | c, 3 /* grad_idx */, &s)); |
456 | int idx = is_sparse ? 5 : 4; |
457 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr |
458 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l1 |
459 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l2 |
460 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // global step |
461 | if (c->num_outputs() > 0) { |
462 | c->set_output(0, s); |
463 | } |
464 | return OkStatus(); |
465 | } |
466 | |
// Adagrad Dual Averaging update, ref variable.
REGISTER_OP("ApplyAdagradDA")
    .Input("var: Ref(T)")
    .Input("gradient_accumulator: Ref(T)")
    .Input("gradient_squared_accumulator: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

// Sparse (row-indexed) Adagrad DA update, ref variable.
REGISTER_OP("SparseApplyAdagradDA")
    .Input("var: Ref(T)")
    .Input("gradient_accumulator: Ref(T)")
    .Input("gradient_squared_accumulator: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// Adagrad DA update, resource variable.
REGISTER_OP("ResourceApplyAdagradDA")
    .Input("var: resource")
    .Input("gradient_accumulator: resource")
    .Input("gradient_squared_accumulator: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse Adagrad DA update, resource variable.
REGISTER_OP("ResourceSparseApplyAdagradDA")
    .Input("var: resource")
    .Input("gradient_accumulator: resource")
    .Input("gradient_squared_accumulator: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(
        ApplyAdagradDAShapeFn</*is_sparse=*/true, /*is_resource=*/true>);
528 | |
529 | template <bool is_sparse, bool is_resource> |
530 | static Status ApplyFtrlShapeFn(InferenceContext* c) { |
531 | ShapeHandle unused; |
532 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
533 | TF_RETURN_IF_ERROR( |
534 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // accum |
535 | TF_RETURN_IF_ERROR( |
536 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s)); // linear |
537 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
538 | c, 3 /* grad_idx */, &s)); |
539 | int idx = is_sparse ? 5 : 4; |
540 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr |
541 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l1 |
542 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // l2 |
543 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // lr_power |
544 | if (c->num_outputs() > 0) { |
545 | c->set_output(0, s); |
546 | } |
547 | return OkStatus(); |
548 | } |
549 | |
// FTRL-proximal update, ref variable.
REGISTER_OP("ApplyFtrl")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

// Sparse (row-indexed) FTRL-proximal update, ref variable.
REGISTER_OP("SparseApplyFtrl")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// FTRL-proximal update, resource variable.
REGISTER_OP("ResourceApplyFtrl")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse FTRL-proximal update, resource variable.
REGISTER_OP("ResourceSparseApplyFtrl")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

// FTRL V2 (adds l2_shrinkage) update, ref variable.
REGISTER_OP("ApplyFtrlV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

// Sparse FTRL V2 update, ref variable.
REGISTER_OP("SparseApplyFtrlV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// FTRL V2 update, resource variable.
REGISTER_OP("ResourceApplyFtrlV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse FTRL V2 update, resource variable.
REGISTER_OP("ResourceSparseApplyFtrlV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("multiply_linear_by_lr: bool = false")
    .SetShapeFn(ApplyFtrlShapeFn</*is_sparse=*/true, /*is_resource=*/true>);
677 | |
678 | template <bool is_sparse, bool is_resource> |
679 | static Status ApplyMomentumShapeFn(InferenceContext* c) { |
680 | ShapeHandle unused; |
681 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
682 | TF_RETURN_IF_ERROR( |
683 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // accum |
684 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr |
685 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
686 | c, 3 /* grad_idx */, &s)); |
687 | int idx = is_sparse ? 5 : 4; |
688 | TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused)); // momentum |
689 | if (c->num_outputs() > 0) { |
690 | c->set_output(0, s); |
691 | } |
692 | return OkStatus(); |
693 | } |
694 | |
// Momentum (optionally Nesterov) update, ref variable.
REGISTER_OP("ApplyMomentum")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/false, /*is_resource=*/false>);

// Sparse (row-indexed) momentum update, ref variable.
REGISTER_OP("SparseApplyMomentum")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/true, /*is_resource=*/false>);

// Momentum update, resource variable.
REGISTER_OP("ResourceApplyMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse momentum update, resource variable.
REGISTER_OP("ResourceSparseApplyMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyMomentumShapeFn</*is_sparse=*/true, /*is_resource=*/true>);

// Keras-flavored momentum update (different accumulator update rule in the
// kernel; same input signature for shape inference), resource variable.
REGISTER_OP("ResourceApplyKerasMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(
        ApplyMomentumShapeFn</*is_sparse=*/false, /*is_resource=*/true>);

// Sparse Keras-flavored momentum update, resource variable.
REGISTER_OP("ResourceSparseApplyKerasMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyMomentumShapeFn</*is_sparse=*/true, /*is_resource=*/true>);
772 | |
773 | template <bool is_resource> |
774 | static Status ApplyAdamShapeFn(InferenceContext* c) { |
775 | ShapeHandle unused; |
776 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
777 | TF_RETURN_IF_ERROR( |
778 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // m |
779 | TF_RETURN_IF_ERROR( |
780 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s)); // v |
781 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // beta1_power |
782 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // beta2_power |
783 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // lr |
784 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // beta1 |
785 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // beta2 |
786 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); // epsilon |
787 | TF_RETURN_IF_ERROR( |
788 | HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>( |
789 | c, 9 /* grad_idx */, &s)); |
790 | if (c->num_outputs() > 0) { |
791 | c->set_output(0, s); |
792 | } |
793 | return OkStatus(); |
794 | } |
795 | |
// Adam update, ref variable.
REGISTER_OP("ApplyAdam")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("v: Ref(T)")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyAdamShapeFn</*is_resource=*/false>);

// Adam update, resource variable.
REGISTER_OP("ResourceApplyAdam")
    .Input("var: resource")
    .Input("m: resource")
    .Input("v: resource")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn(ApplyAdamShapeFn</*is_resource=*/true>);
828 | |
829 | template <bool is_resource> |
830 | static Status ApplyAdamWithAmsgradShapeFn(InferenceContext* c) { |
831 | ShapeHandle unused; |
832 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
833 | TF_RETURN_IF_ERROR( |
834 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // m |
835 | TF_RETURN_IF_ERROR( |
836 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s)); // v |
837 | TF_RETURN_IF_ERROR( |
838 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 3), &s)); // vhat |
839 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // beta1_power |
840 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta2_power |
841 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // lr |
842 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // beta1 |
843 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); // beta2 |
844 | TF_RETURN_IF_ERROR(c->WithRank(c->input(9), 0, &unused)); // epsilon |
845 | TF_RETURN_IF_ERROR( |
846 | HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>( |
847 | c, 10 /* grad_idx */, &s)); |
848 | if (c->num_outputs() > 0) { |
849 | c->set_output(0, s); |
850 | } |
851 | return OkStatus(); |
852 | } |
853 | |
854 | REGISTER_OP("ResourceApplyAdamWithAmsgrad" ) |
855 | .Input("var: resource" ) |
856 | .Input("m: resource" ) |
857 | .Input("v: resource" ) |
858 | .Input("vhat: resource" ) |
859 | .Input("beta1_power: T" ) |
860 | .Input("beta2_power: T" ) |
861 | .Input("lr: T" ) |
862 | .Input("beta1: T" ) |
863 | .Input("beta2: T" ) |
864 | .Input("epsilon: T" ) |
865 | .Input("grad: T" ) |
866 | .Attr("T: numbertype" ) |
867 | .Attr("use_locking: bool = false" ) |
868 | .SetShapeFn(ApplyAdamWithAmsgradShapeFn</*is_resource=*/true>); |
869 | |
870 | template <bool is_resource> |
871 | static Status ApplyAdaMaxShapeFn(InferenceContext* c) { |
872 | ShapeHandle unused; |
873 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
874 | TF_RETURN_IF_ERROR( |
875 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // m |
876 | TF_RETURN_IF_ERROR( |
877 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s)); // v |
878 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // beta1_power |
879 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr |
880 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta1 |
881 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // beta2 |
882 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // epsilon |
883 | TF_RETURN_IF_ERROR( |
884 | HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>( |
885 | c, 8 /* grad_idx */, &s)); |
886 | if (c->num_outputs() > 0) { |
887 | c->set_output(0, s); |
888 | } |
889 | return OkStatus(); |
890 | } |
891 | |
892 | REGISTER_OP("ApplyAdaMax" ) |
893 | .Input("var: Ref(T)" ) |
894 | .Input("m: Ref(T)" ) |
895 | .Input("v: Ref(T)" ) |
896 | .Input("beta1_power: T" ) |
897 | .Input("lr: T" ) |
898 | .Input("beta1: T" ) |
899 | .Input("beta2: T" ) |
900 | .Input("epsilon: T" ) |
901 | .Input("grad: T" ) |
902 | .Output("out: Ref(T)" ) |
903 | .Attr("T: numbertype" ) |
904 | .Attr("use_locking: bool = false" ) |
905 | .SetShapeFn(ApplyAdaMaxShapeFn</*is_resource=*/false>); |
906 | |
907 | REGISTER_OP("ResourceApplyAdaMax" ) |
908 | .Input("var: resource" ) |
909 | .Input("m: resource" ) |
910 | .Input("v: resource" ) |
911 | .Input("beta1_power: T" ) |
912 | .Input("lr: T" ) |
913 | .Input("beta1: T" ) |
914 | .Input("beta2: T" ) |
915 | .Input("epsilon: T" ) |
916 | .Input("grad: T" ) |
917 | .Attr("T: numbertype" ) |
918 | .Attr("use_locking: bool = false" ) |
919 | .SetShapeFn(ApplyAdaMaxShapeFn</*is_resource=*/true>); |
920 | |
921 | template <bool is_sparse, bool is_resource> |
922 | static Status ApplyRMSPropShapeFn(InferenceContext* c) { |
923 | ShapeHandle unused; |
924 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
925 | TF_RETURN_IF_ERROR( |
926 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // ms |
927 | TF_RETURN_IF_ERROR( |
928 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s)); // mom |
929 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // lr |
930 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // rho |
931 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // momentum |
932 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // epsilon |
933 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
934 | c, 7 /* grad_idx */, &s)); |
935 | if (c->num_outputs() > 0) { |
936 | c->set_output(0, s); |
937 | } |
938 | return OkStatus(); |
939 | } |
940 | |
941 | REGISTER_OP("ApplyRMSProp" ) |
942 | .Input("var: Ref(T)" ) |
943 | .Input("ms: Ref(T)" ) |
944 | .Input("mom: Ref(T)" ) |
945 | .Input("lr: T" ) |
946 | .Input("rho: T" ) |
947 | .Input("momentum: T" ) |
948 | .Input("epsilon: T" ) |
949 | .Input("grad: T" ) |
950 | .Output("out: Ref(T)" ) |
951 | .Attr("T: numbertype" ) |
952 | .Attr("use_locking: bool = false" ) |
953 | .SetShapeFn( |
954 | ApplyRMSPropShapeFn</*is_sparse=*/false, /*is_resource=*/false>); |
955 | |
956 | REGISTER_OP("SparseApplyRMSProp" ) |
957 | .Input("var: Ref(T)" ) |
958 | .Input("ms: Ref(T)" ) |
959 | .Input("mom: Ref(T)" ) |
960 | .Input("lr: T" ) |
961 | .Input("rho: T" ) |
962 | .Input("momentum: T" ) |
963 | .Input("epsilon: T" ) |
964 | .Input("grad: T" ) |
965 | .Input("indices: Tindices" ) |
966 | .Output("out: Ref(T)" ) |
967 | .Attr("T: numbertype" ) |
968 | .Attr("Tindices: {int32, int64}" ) |
969 | .Attr("use_locking: bool = false" ) |
970 | .SetShapeFn(ApplyRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/false>); |
971 | |
972 | REGISTER_OP("ResourceApplyRMSProp" ) |
973 | .Input("var: resource" ) |
974 | .Input("ms: resource" ) |
975 | .Input("mom: resource" ) |
976 | .Input("lr: T" ) |
977 | .Input("rho: T" ) |
978 | .Input("momentum: T" ) |
979 | .Input("epsilon: T" ) |
980 | .Input("grad: T" ) |
981 | .Attr("T: numbertype" ) |
982 | .Attr("use_locking: bool = false" ) |
983 | .SetShapeFn(ApplyRMSPropShapeFn</*is_sparse=*/false, /*is_resource=*/true>); |
984 | |
985 | REGISTER_OP("ResourceSparseApplyRMSProp" ) |
986 | .Input("var: resource" ) |
987 | .Input("ms: resource" ) |
988 | .Input("mom: resource" ) |
989 | .Input("lr: T" ) |
990 | .Input("rho: T" ) |
991 | .Input("momentum: T" ) |
992 | .Input("epsilon: T" ) |
993 | .Input("grad: T" ) |
994 | .Input("indices: Tindices" ) |
995 | .Attr("T: numbertype" ) |
996 | .Attr("Tindices: {int32, int64}" ) |
997 | .Attr("use_locking: bool = false" ) |
998 | .SetShapeFn(ApplyRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/true>); |
999 | |
1000 | template <bool is_sparse, bool is_resource> |
1001 | static Status ApplyCenteredRMSPropShapeFn(InferenceContext* c) { |
1002 | ShapeHandle unused; |
1003 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
1004 | TF_RETURN_IF_ERROR( |
1005 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // ms |
1006 | TF_RETURN_IF_ERROR( |
1007 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 2), &s)); // mg |
1008 | TF_RETURN_IF_ERROR( |
1009 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 3), &s)); // mom |
1010 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // lr |
1011 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // rho |
1012 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // momentum |
1013 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); // epsilon |
1014 | TF_RETURN_IF_ERROR(HandleGradAndIndicesInputs<is_sparse, is_resource>( |
1015 | c, 8 /* grad_idx */, &s)); |
1016 | if (c->num_outputs() > 0) { |
1017 | c->set_output(0, s); |
1018 | } |
1019 | return OkStatus(); |
1020 | } |
1021 | |
1022 | REGISTER_OP("ApplyCenteredRMSProp" ) |
1023 | .Input("var: Ref(T)" ) |
1024 | .Input("mg: Ref(T)" ) |
1025 | .Input("ms: Ref(T)" ) |
1026 | .Input("mom: Ref(T)" ) |
1027 | .Input("lr: T" ) |
1028 | .Input("rho: T" ) |
1029 | .Input("momentum: T" ) |
1030 | .Input("epsilon: T" ) |
1031 | .Input("grad: T" ) |
1032 | .Output("out: Ref(T)" ) |
1033 | .Attr("T: numbertype" ) |
1034 | .Attr("use_locking: bool = false" ) |
1035 | .SetShapeFn(ApplyCenteredRMSPropShapeFn</*is_sparse=*/false, |
1036 | /*is_resource=*/false>); |
1037 | |
1038 | REGISTER_OP("SparseApplyCenteredRMSProp" ) |
1039 | .Input("var: Ref(T)" ) |
1040 | .Input("mg: Ref(T)" ) |
1041 | .Input("ms: Ref(T)" ) |
1042 | .Input("mom: Ref(T)" ) |
1043 | .Input("lr: T" ) |
1044 | .Input("rho: T" ) |
1045 | .Input("momentum: T" ) |
1046 | .Input("epsilon: T" ) |
1047 | .Input("grad: T" ) |
1048 | .Input("indices: Tindices" ) |
1049 | .Output("out: Ref(T)" ) |
1050 | .Attr("T: numbertype" ) |
1051 | .Attr("Tindices: {int32, int64}" ) |
1052 | .Attr("use_locking: bool = false" ) |
1053 | .SetShapeFn( |
1054 | ApplyCenteredRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/false>); |
1055 | |
1056 | REGISTER_OP("ResourceApplyCenteredRMSProp" ) |
1057 | .Input("var: resource" ) |
1058 | .Input("mg: resource" ) |
1059 | .Input("ms: resource" ) |
1060 | .Input("mom: resource" ) |
1061 | .Input("lr: T" ) |
1062 | .Input("rho: T" ) |
1063 | .Input("momentum: T" ) |
1064 | .Input("epsilon: T" ) |
1065 | .Input("grad: T" ) |
1066 | .Attr("T: numbertype" ) |
1067 | .Attr("use_locking: bool = false" ) |
1068 | .SetShapeFn( |
1069 | ApplyCenteredRMSPropShapeFn</*is_sparse=*/false, /*is_resource=*/true>); |
1070 | |
1071 | REGISTER_OP("ResourceSparseApplyCenteredRMSProp" ) |
1072 | .Input("var: resource" ) |
1073 | .Input("mg: resource" ) |
1074 | .Input("ms: resource" ) |
1075 | .Input("mom: resource" ) |
1076 | .Input("lr: T" ) |
1077 | .Input("rho: T" ) |
1078 | .Input("momentum: T" ) |
1079 | .Input("epsilon: T" ) |
1080 | .Input("grad: T" ) |
1081 | .Input("indices: Tindices" ) |
1082 | .Attr("T: numbertype" ) |
1083 | .Attr("Tindices: {int32, int64}" ) |
1084 | .Attr("use_locking: bool = false" ) |
1085 | .SetShapeFn( |
1086 | ApplyCenteredRMSPropShapeFn</*is_sparse=*/true, /*is_resource=*/true>); |
1087 | |
1088 | template <bool is_resource> |
1089 | static Status ApplyAddSignShapeFn(InferenceContext* c) { |
1090 | ShapeHandle unused; |
1091 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
1092 | TF_RETURN_IF_ERROR( |
1093 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // m |
1094 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr |
1095 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // alpha |
1096 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // sign_decay |
1097 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta |
1098 | TF_RETURN_IF_ERROR( |
1099 | HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>( |
1100 | c, 6 /* grad_idx */, &s)); |
1101 | if (c->num_outputs() > 0) { |
1102 | c->set_output(0, s); |
1103 | } |
1104 | return OkStatus(); |
1105 | } |
1106 | |
1107 | REGISTER_OP("ApplyAddSign" ) |
1108 | .Input("var: Ref(T)" ) |
1109 | .Input("m: Ref(T)" ) |
1110 | .Input("lr: T" ) |
1111 | .Input("alpha: T" ) |
1112 | .Input("sign_decay: T" ) |
1113 | .Input("beta: T" ) |
1114 | .Input("grad: T" ) |
1115 | .Output("out: Ref(T)" ) |
1116 | .Attr("T: numbertype" ) |
1117 | .Attr("use_locking: bool = false" ) |
1118 | .SetShapeFn(ApplyAddSignShapeFn</*is_resource=*/false>); |
1119 | |
1120 | REGISTER_OP("ResourceApplyAddSign" ) |
1121 | .Input("var: resource" ) |
1122 | .Input("m: resource" ) |
1123 | .Input("lr: T" ) |
1124 | .Input("alpha: T" ) |
1125 | .Input("sign_decay: T" ) |
1126 | .Input("beta: T" ) |
1127 | .Input("grad: T" ) |
1128 | .Attr("T: numbertype" ) |
1129 | .Attr("use_locking: bool = false" ) |
1130 | .SetShapeFn(ApplyAddSignShapeFn</*is_resource=*/true>); |
1131 | |
1132 | template <bool is_resource> |
1133 | static Status ApplyPowerSignShapeFn(InferenceContext* c) { |
1134 | ShapeHandle unused; |
1135 | ShapeHandle s = ShapeOrHandleShape<is_resource>(c, 0); // var |
1136 | TF_RETURN_IF_ERROR( |
1137 | c->Merge(s, ShapeOrHandleShape<is_resource>(c, 1), &s)); // m |
1138 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); // lr |
1139 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); // logbase |
1140 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); // sign_delay |
1141 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); // beta |
1142 | TF_RETURN_IF_ERROR( |
1143 | HandleGradAndIndicesInputs</*is_sparse=*/false, is_resource>( |
1144 | c, 6 /* grad_idx */, &s)); |
1145 | if (c->num_outputs() > 0) { |
1146 | c->set_output(0, s); |
1147 | } |
1148 | return OkStatus(); |
1149 | } |
1150 | |
1151 | REGISTER_OP("ApplyPowerSign" ) |
1152 | .Input("var: Ref(T)" ) |
1153 | .Input("m: Ref(T)" ) |
1154 | .Input("lr: T" ) |
1155 | .Input("logbase: T" ) |
1156 | .Input("sign_decay: T" ) |
1157 | .Input("beta: T" ) |
1158 | .Input("grad: T" ) |
1159 | .Output("out: Ref(T)" ) |
1160 | .Attr("T: numbertype" ) |
1161 | .Attr("use_locking: bool = false" ) |
1162 | .SetShapeFn(ApplyPowerSignShapeFn</*is_resource=*/false>); |
1163 | |
1164 | REGISTER_OP("ResourceApplyPowerSign" ) |
1165 | .Input("var: resource" ) |
1166 | .Input("m: resource" ) |
1167 | .Input("lr: T" ) |
1168 | .Input("logbase: T" ) |
1169 | .Input("sign_decay: T" ) |
1170 | .Input("beta: T" ) |
1171 | .Input("grad: T" ) |
1172 | .Attr("T: numbertype" ) |
1173 | .Attr("use_locking: bool = false" ) |
1174 | .SetShapeFn(ApplyPowerSignShapeFn</*is_resource=*/true>); |
1175 | |
1176 | } // namespace tensorflow |
1177 | |