1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
17#define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
18
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20#include "tensorflow/core/framework/tensor_types.h"
21#include "tensorflow/core/lib/core/status.h"
22#include "tensorflow/core/platform/types.h"
23
24namespace tensorflow {
25namespace functor {
26
// Each training algorithm has an ApplyXYZ functor struct declared in
// this header file. The functors are specialized for each device
// (CPUDevice in training_ops.cc, GPUDevice in training_ops_gpu.cc).
30
31template <typename Device, typename T>
32struct ApplyGradientDescent {
33 void operator()(const Device& d, typename TTypes<T>::Flat var,
34 typename TTypes<T>::ConstScalar alpha,
35 typename TTypes<T>::ConstFlat delta);
36};
37
38template <typename Device, typename T>
39struct ApplyAdadelta {
40 void operator()(const Device& d, typename TTypes<T>::Flat var,
41 typename TTypes<T>::Flat accum,
42 typename TTypes<T>::Flat accum_update,
43 typename TTypes<T>::ConstScalar lr,
44 typename TTypes<T>::ConstScalar rho,
45 typename TTypes<T>::ConstScalar epsilon,
46 typename TTypes<T>::ConstFlat grad);
47};
48
49template <typename Device, typename T, typename Tindex>
50struct SparseApplyAdadelta {
51 void operator()(const Device& d, typename TTypes<T>::Matrix var,
52 typename TTypes<T>::Matrix accum,
53 typename TTypes<T>::Matrix accum_update,
54 typename TTypes<T>::ConstScalar lr,
55 typename TTypes<T>::ConstScalar rho,
56 typename TTypes<T>::ConstScalar epsilon,
57 typename TTypes<T>::ConstMatrix grad,
58 typename TTypes<Tindex>::ConstFlat indices);
59};
60
61template <typename Device, typename T>
62struct FobosElasticNet {
63 void operator()(const Device& d, typename TTypes<T>::Flat var,
64 typename TTypes<T>::ConstScalar lr,
65 typename TTypes<T>::ConstScalar l1,
66 typename TTypes<T>::ConstScalar l2,
67 typename TTypes<T>::ConstFlat grad);
68};
69
70template <typename Device, typename T>
71struct ApplyProximalGradientDescent {
72 void operator()(const Device& d, typename TTypes<T>::Flat var,
73 typename TTypes<T>::ConstScalar lr,
74 typename TTypes<T>::ConstScalar l1,
75 typename TTypes<T>::ConstScalar l2,
76 typename TTypes<T>::ConstFlat grad);
77};
78
79template <typename Device, typename T>
80struct ApplyAdagrad {
81 void operator()(const Device& d, typename TTypes<T>::Flat var,
82 typename TTypes<T>::Flat accum,
83 typename TTypes<T>::ConstScalar lr,
84 typename TTypes<T>::ConstFlat grad, bool update_slots);
85};
86
87template <typename Device, typename T>
88struct ApplyAdagradV2 {
89 void operator()(const Device& d, typename TTypes<T>::Flat var,
90 typename TTypes<T>::Flat accum,
91 typename TTypes<T>::ConstScalar lr,
92 typename TTypes<T>::ConstScalar epsilon,
93 typename TTypes<T>::ConstFlat grad, bool update_slots);
94};
95
96template <typename Device, typename T>
97struct ApplyAdagradDA {
98 void operator()(const Device& d, typename TTypes<T>::Flat var,
99 typename TTypes<T>::Flat gradient_accum,
100 typename TTypes<T>::Flat gradient_squared_accum,
101 typename TTypes<T>::ConstScalar lr, int64_t global_step,
102 typename TTypes<T>::ConstScalar l1,
103 typename TTypes<T>::ConstScalar l2,
104 typename TTypes<T>::ConstFlat grad);
105};
106
107template <typename Device, typename T, typename Tindex, bool has_epsilon>
108struct SparseApplyAdagrad {
109 // Note that epsilon is ignored if has_epsilon is false.
110 Status operator()(const Device& d, typename TTypes<T>::Matrix var,
111 typename TTypes<T>::Matrix accum,
112 typename TTypes<T>::ConstScalar lr,
113 typename TTypes<T>::ConstScalar epsilon,
114 typename TTypes<T>::ConstMatrix grad,
115 typename TTypes<Tindex>::ConstVec indices,
116 int64_t inner_dim, bool update_slots);
117};
118
119template <typename Device, typename T>
120struct ApplyProximalAdagrad {
121 void operator()(const Device& d, typename TTypes<T>::Flat var,
122 typename TTypes<T>::Flat accum,
123 typename TTypes<T>::ConstScalar lr,
124 typename TTypes<T>::ConstScalar l1,
125 typename TTypes<T>::ConstScalar l2,
126 typename TTypes<T>::ConstFlat grad);
127};
128
129template <typename Device, typename T, typename Tindex>
130struct SparseApplyProximalAdagrad {
131 Status operator()(const Device& d, typename TTypes<T>::Matrix var,
132 typename TTypes<T>::Matrix accum,
133 typename TTypes<T>::ConstScalar lr,
134 typename TTypes<T>::ConstScalar l1,
135 typename TTypes<T>::ConstScalar l2,
136 typename TTypes<T>::ConstMatrix grad,
137 typename TTypes<Tindex>::ConstVec indices,
138 int64_t inner_dim);
139};
140
141template <typename Device, typename T>
142struct ApplyFtrl {
143 void operator()(const Device& d, typename TTypes<T>::Flat var,
144 typename TTypes<T>::Flat accum,
145 typename TTypes<T>::Flat linear,
146 typename TTypes<T>::ConstFlat grad,
147 typename TTypes<T>::ConstScalar lr,
148 typename TTypes<T>::ConstScalar l1,
149 typename TTypes<T>::ConstScalar l2,
150 typename TTypes<T>::ConstScalar lr_power);
151};
152
153template <typename Device, typename T>
154struct ApplyFtrlMultiplyLinearByLr {
155 void operator()(const Device& d, typename TTypes<T>::Flat var,
156 typename TTypes<T>::Flat accum,
157 typename TTypes<T>::Flat linear,
158 typename TTypes<T>::ConstFlat grad,
159 typename TTypes<T>::ConstScalar lr,
160 typename TTypes<T>::ConstScalar l1,
161 typename TTypes<T>::ConstScalar l2,
162 typename TTypes<T>::ConstScalar lr_power);
163};
164
165template <typename Device, typename T>
166struct ApplyFtrlV2 {
167 void operator()(const Device& d, typename TTypes<T>::Flat var,
168 typename TTypes<T>::Flat accum,
169 typename TTypes<T>::Flat linear,
170 typename TTypes<T>::ConstFlat grad,
171 typename TTypes<T>::ConstScalar lr,
172 typename TTypes<T>::ConstScalar l1,
173 typename TTypes<T>::ConstScalar l2,
174 typename TTypes<T>::ConstScalar l2_shrinkage,
175 typename TTypes<T>::ConstScalar lr_power);
176};
177
178template <typename Device, typename T>
179struct ApplyFtrlV2MultiplyLinearByLr {
180 void operator()(const Device& d, typename TTypes<T>::Flat var,
181 typename TTypes<T>::Flat accum,
182 typename TTypes<T>::Flat linear,
183 typename TTypes<T>::ConstFlat grad,
184 typename TTypes<T>::ConstScalar lr,
185 typename TTypes<T>::ConstScalar l1,
186 typename TTypes<T>::ConstScalar l2,
187 typename TTypes<T>::ConstScalar l2_shrinkage,
188 typename TTypes<T>::ConstScalar lr_power);
189};
190
191template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
192struct SparseApplyFtrl {
193 Status operator()(const Device& d, typename TTypes<T>::Matrix var_flat,
194 typename TTypes<T>::Matrix accum_flat,
195 typename TTypes<T>::Matrix linear_flat,
196 typename TTypes<T>::ConstScalar lr,
197 typename TTypes<T>::ConstScalar l1,
198 typename TTypes<T>::ConstScalar l2,
199 typename TTypes<T>::ConstScalar l2_shrinkage,
200 typename TTypes<T>::ConstScalar lr_power,
201 typename TTypes<T>::ConstMatrix grad_flat,
202 typename TTypes<Tindex>::ConstVec indices_vec,
203 int64_t inner_dim, bool multiply_linear_by_lr);
204};
205
206template <typename Device, typename T>
207struct ApplyMomentum {
208 void operator()(const Device& d, typename TTypes<T>::Flat var,
209 typename TTypes<T>::Flat accum,
210 typename TTypes<T>::ConstScalar lr,
211 typename TTypes<T>::ConstFlat grad,
212 typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
213};
214
215template <typename Device, typename T>
216struct ApplyKerasMomentum {
217 void operator()(const Device& d, typename TTypes<T>::Flat var,
218 typename TTypes<T>::Flat accum,
219 typename TTypes<T>::ConstScalar lr,
220 typename TTypes<T>::ConstFlat grad,
221 typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
222};
223
224template <typename Device, typename T, typename Tindex>
225struct SparseApplyKerasMomentum {
226 Tindex operator()(const Device& d, typename TTypes<T>::Matrix var,
227 typename TTypes<T>::Matrix accum,
228 typename TTypes<T>::ConstScalar lr,
229 typename TTypes<T>::ConstMatrix grad,
230 typename TTypes<Tindex>::ConstFlat indices,
231 typename TTypes<T>::ConstScalar momentum,
232 bool use_nesterov);
233};
234
235template <typename Device, typename T>
236struct ApplyAdam {
237 void operator()(const Device& d, typename TTypes<T>::Flat var,
238 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
239 typename TTypes<T>::ConstScalar beta1_power,
240 typename TTypes<T>::ConstScalar beta2_power,
241 typename TTypes<T>::ConstScalar lr,
242 typename TTypes<T>::ConstScalar beta1,
243 typename TTypes<T>::ConstScalar beta2,
244 typename TTypes<T>::ConstScalar epsilon,
245 typename TTypes<T>::ConstFlat grad, bool use_nesterov);
246};
247
248template <typename Device, typename T>
249struct ApplyAdamWithAmsgrad {
250 void operator()(const Device& d, typename TTypes<T>::Flat var,
251 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
252 typename TTypes<T>::Flat vhat,
253 typename TTypes<T>::ConstScalar beta1_power,
254 typename TTypes<T>::ConstScalar beta2_power,
255 typename TTypes<T>::ConstScalar lr,
256 typename TTypes<T>::ConstScalar beta1,
257 typename TTypes<T>::ConstScalar beta2,
258 typename TTypes<T>::ConstScalar epsilon,
259 typename TTypes<T>::ConstFlat grad);
260};
261
262template <typename Device, typename T>
263struct ApplyAdaMax {
264 void operator()(const Device& d, typename TTypes<T>::Flat var,
265 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
266 typename TTypes<T>::ConstScalar beta1_power,
267 typename TTypes<T>::ConstScalar lr,
268 typename TTypes<T>::ConstScalar beta1,
269 typename TTypes<T>::ConstScalar beta2,
270 typename TTypes<T>::ConstScalar epsilon,
271 typename TTypes<T>::ConstFlat grad);
272};
273
274template <typename Device, typename T>
275struct ApplyRMSProp {
276 void operator()(const Device& d, typename TTypes<T>::Flat var,
277 typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
278 typename TTypes<T>::ConstScalar lr,
279 typename TTypes<T>::ConstScalar rho,
280 typename TTypes<T>::ConstScalar momentum,
281 typename TTypes<T>::ConstScalar epsilon,
282 typename TTypes<T>::ConstFlat grad);
283};
284
285template <typename Device, typename T>
286struct ApplyCenteredRMSProp {
287 void operator()(const Device& d, typename TTypes<T>::Flat var,
288 typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
289 typename TTypes<T>::Flat mom,
290 typename TTypes<T>::ConstScalar lr,
291 typename TTypes<T>::ConstScalar rho,
292 typename TTypes<T>::ConstScalar momentum,
293 typename TTypes<T>::ConstScalar epsilon,
294 typename TTypes<T>::ConstFlat grad);
295};
296
297template <typename Device, typename T>
298struct ApplyAddSign {
299 void operator()(const Device& d, typename TTypes<T>::Flat var,
300 typename TTypes<T>::Flat m,
301 typename TTypes<T>::ConstScalar lr,
302 typename TTypes<T>::ConstScalar alpha,
303 typename TTypes<T>::ConstScalar sign_decay,
304 typename TTypes<T>::ConstScalar beta,
305 typename TTypes<T>::ConstFlat grad);
306};
307
308template <typename Device, typename T>
309struct ApplyPowerSign {
310 void operator()(const Device& d, typename TTypes<T>::Flat var,
311 typename TTypes<T>::Flat m,
312 typename TTypes<T>::ConstScalar lr,
313 typename TTypes<T>::ConstScalar logbase,
314 typename TTypes<T>::ConstScalar sign_decay,
315 typename TTypes<T>::ConstScalar beta,
316 typename TTypes<T>::ConstFlat grad);
317};
318
319} // end namespace functor
320} // end namespace tensorflow
321
322#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
323