1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/tensor_types.h" |
21 | #include "tensorflow/core/lib/core/status.h" |
22 | #include "tensorflow/core/platform/types.h" |
23 | |
24 | namespace tensorflow { |
25 | namespace functor { |
26 | |
// Each training algorithm has an ApplyXYZ functor struct declared in
// this header file. They are specialized for different devices
// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc).
30 | |
31 | template <typename Device, typename T> |
32 | struct ApplyGradientDescent { |
33 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
34 | typename TTypes<T>::ConstScalar alpha, |
35 | typename TTypes<T>::ConstFlat delta); |
36 | }; |
37 | |
38 | template <typename Device, typename T> |
39 | struct ApplyAdadelta { |
40 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
41 | typename TTypes<T>::Flat accum, |
42 | typename TTypes<T>::Flat accum_update, |
43 | typename TTypes<T>::ConstScalar lr, |
44 | typename TTypes<T>::ConstScalar rho, |
45 | typename TTypes<T>::ConstScalar epsilon, |
46 | typename TTypes<T>::ConstFlat grad); |
47 | }; |
48 | |
49 | template <typename Device, typename T, typename Tindex> |
50 | struct SparseApplyAdadelta { |
51 | void operator()(const Device& d, typename TTypes<T>::Matrix var, |
52 | typename TTypes<T>::Matrix accum, |
53 | typename TTypes<T>::Matrix accum_update, |
54 | typename TTypes<T>::ConstScalar lr, |
55 | typename TTypes<T>::ConstScalar rho, |
56 | typename TTypes<T>::ConstScalar epsilon, |
57 | typename TTypes<T>::ConstMatrix grad, |
58 | typename TTypes<Tindex>::ConstFlat indices); |
59 | }; |
60 | |
61 | template <typename Device, typename T> |
62 | struct FobosElasticNet { |
63 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
64 | typename TTypes<T>::ConstScalar lr, |
65 | typename TTypes<T>::ConstScalar l1, |
66 | typename TTypes<T>::ConstScalar l2, |
67 | typename TTypes<T>::ConstFlat grad); |
68 | }; |
69 | |
70 | template <typename Device, typename T> |
71 | struct ApplyProximalGradientDescent { |
72 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
73 | typename TTypes<T>::ConstScalar lr, |
74 | typename TTypes<T>::ConstScalar l1, |
75 | typename TTypes<T>::ConstScalar l2, |
76 | typename TTypes<T>::ConstFlat grad); |
77 | }; |
78 | |
79 | template <typename Device, typename T> |
80 | struct ApplyAdagrad { |
81 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
82 | typename TTypes<T>::Flat accum, |
83 | typename TTypes<T>::ConstScalar lr, |
84 | typename TTypes<T>::ConstFlat grad, bool update_slots); |
85 | }; |
86 | |
87 | template <typename Device, typename T> |
88 | struct ApplyAdagradV2 { |
89 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
90 | typename TTypes<T>::Flat accum, |
91 | typename TTypes<T>::ConstScalar lr, |
92 | typename TTypes<T>::ConstScalar epsilon, |
93 | typename TTypes<T>::ConstFlat grad, bool update_slots); |
94 | }; |
95 | |
96 | template <typename Device, typename T> |
97 | struct ApplyAdagradDA { |
98 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
99 | typename TTypes<T>::Flat gradient_accum, |
100 | typename TTypes<T>::Flat gradient_squared_accum, |
101 | typename TTypes<T>::ConstScalar lr, int64_t global_step, |
102 | typename TTypes<T>::ConstScalar l1, |
103 | typename TTypes<T>::ConstScalar l2, |
104 | typename TTypes<T>::ConstFlat grad); |
105 | }; |
106 | |
107 | template <typename Device, typename T, typename Tindex, bool has_epsilon> |
108 | struct SparseApplyAdagrad { |
109 | // Note that epsilon is ignored if has_epsilon is false. |
110 | Status operator()(const Device& d, typename TTypes<T>::Matrix var, |
111 | typename TTypes<T>::Matrix accum, |
112 | typename TTypes<T>::ConstScalar lr, |
113 | typename TTypes<T>::ConstScalar epsilon, |
114 | typename TTypes<T>::ConstMatrix grad, |
115 | typename TTypes<Tindex>::ConstVec indices, |
116 | int64_t inner_dim, bool update_slots); |
117 | }; |
118 | |
119 | template <typename Device, typename T> |
120 | struct ApplyProximalAdagrad { |
121 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
122 | typename TTypes<T>::Flat accum, |
123 | typename TTypes<T>::ConstScalar lr, |
124 | typename TTypes<T>::ConstScalar l1, |
125 | typename TTypes<T>::ConstScalar l2, |
126 | typename TTypes<T>::ConstFlat grad); |
127 | }; |
128 | |
129 | template <typename Device, typename T, typename Tindex> |
130 | struct SparseApplyProximalAdagrad { |
131 | Status operator()(const Device& d, typename TTypes<T>::Matrix var, |
132 | typename TTypes<T>::Matrix accum, |
133 | typename TTypes<T>::ConstScalar lr, |
134 | typename TTypes<T>::ConstScalar l1, |
135 | typename TTypes<T>::ConstScalar l2, |
136 | typename TTypes<T>::ConstMatrix grad, |
137 | typename TTypes<Tindex>::ConstVec indices, |
138 | int64_t inner_dim); |
139 | }; |
140 | |
141 | template <typename Device, typename T> |
142 | struct ApplyFtrl { |
143 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
144 | typename TTypes<T>::Flat accum, |
145 | typename TTypes<T>::Flat linear, |
146 | typename TTypes<T>::ConstFlat grad, |
147 | typename TTypes<T>::ConstScalar lr, |
148 | typename TTypes<T>::ConstScalar l1, |
149 | typename TTypes<T>::ConstScalar l2, |
150 | typename TTypes<T>::ConstScalar lr_power); |
151 | }; |
152 | |
153 | template <typename Device, typename T> |
154 | struct ApplyFtrlMultiplyLinearByLr { |
155 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
156 | typename TTypes<T>::Flat accum, |
157 | typename TTypes<T>::Flat linear, |
158 | typename TTypes<T>::ConstFlat grad, |
159 | typename TTypes<T>::ConstScalar lr, |
160 | typename TTypes<T>::ConstScalar l1, |
161 | typename TTypes<T>::ConstScalar l2, |
162 | typename TTypes<T>::ConstScalar lr_power); |
163 | }; |
164 | |
165 | template <typename Device, typename T> |
166 | struct ApplyFtrlV2 { |
167 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
168 | typename TTypes<T>::Flat accum, |
169 | typename TTypes<T>::Flat linear, |
170 | typename TTypes<T>::ConstFlat grad, |
171 | typename TTypes<T>::ConstScalar lr, |
172 | typename TTypes<T>::ConstScalar l1, |
173 | typename TTypes<T>::ConstScalar l2, |
174 | typename TTypes<T>::ConstScalar l2_shrinkage, |
175 | typename TTypes<T>::ConstScalar lr_power); |
176 | }; |
177 | |
178 | template <typename Device, typename T> |
179 | struct ApplyFtrlV2MultiplyLinearByLr { |
180 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
181 | typename TTypes<T>::Flat accum, |
182 | typename TTypes<T>::Flat linear, |
183 | typename TTypes<T>::ConstFlat grad, |
184 | typename TTypes<T>::ConstScalar lr, |
185 | typename TTypes<T>::ConstScalar l1, |
186 | typename TTypes<T>::ConstScalar l2, |
187 | typename TTypes<T>::ConstScalar l2_shrinkage, |
188 | typename TTypes<T>::ConstScalar lr_power); |
189 | }; |
190 | |
191 | template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage> |
192 | struct SparseApplyFtrl { |
193 | Status operator()(const Device& d, typename TTypes<T>::Matrix var_flat, |
194 | typename TTypes<T>::Matrix accum_flat, |
195 | typename TTypes<T>::Matrix linear_flat, |
196 | typename TTypes<T>::ConstScalar lr, |
197 | typename TTypes<T>::ConstScalar l1, |
198 | typename TTypes<T>::ConstScalar l2, |
199 | typename TTypes<T>::ConstScalar l2_shrinkage, |
200 | typename TTypes<T>::ConstScalar lr_power, |
201 | typename TTypes<T>::ConstMatrix grad_flat, |
202 | typename TTypes<Tindex>::ConstVec indices_vec, |
203 | int64_t inner_dim, bool multiply_linear_by_lr); |
204 | }; |
205 | |
206 | template <typename Device, typename T> |
207 | struct ApplyMomentum { |
208 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
209 | typename TTypes<T>::Flat accum, |
210 | typename TTypes<T>::ConstScalar lr, |
211 | typename TTypes<T>::ConstFlat grad, |
212 | typename TTypes<T>::ConstScalar momentum, bool use_nesterov); |
213 | }; |
214 | |
215 | template <typename Device, typename T> |
216 | struct ApplyKerasMomentum { |
217 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
218 | typename TTypes<T>::Flat accum, |
219 | typename TTypes<T>::ConstScalar lr, |
220 | typename TTypes<T>::ConstFlat grad, |
221 | typename TTypes<T>::ConstScalar momentum, bool use_nesterov); |
222 | }; |
223 | |
224 | template <typename Device, typename T, typename Tindex> |
225 | struct SparseApplyKerasMomentum { |
226 | Tindex operator()(const Device& d, typename TTypes<T>::Matrix var, |
227 | typename TTypes<T>::Matrix accum, |
228 | typename TTypes<T>::ConstScalar lr, |
229 | typename TTypes<T>::ConstMatrix grad, |
230 | typename TTypes<Tindex>::ConstFlat indices, |
231 | typename TTypes<T>::ConstScalar momentum, |
232 | bool use_nesterov); |
233 | }; |
234 | |
235 | template <typename Device, typename T> |
236 | struct ApplyAdam { |
237 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
238 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, |
239 | typename TTypes<T>::ConstScalar beta1_power, |
240 | typename TTypes<T>::ConstScalar beta2_power, |
241 | typename TTypes<T>::ConstScalar lr, |
242 | typename TTypes<T>::ConstScalar beta1, |
243 | typename TTypes<T>::ConstScalar beta2, |
244 | typename TTypes<T>::ConstScalar epsilon, |
245 | typename TTypes<T>::ConstFlat grad, bool use_nesterov); |
246 | }; |
247 | |
248 | template <typename Device, typename T> |
249 | struct ApplyAdamWithAmsgrad { |
250 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
251 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, |
252 | typename TTypes<T>::Flat vhat, |
253 | typename TTypes<T>::ConstScalar beta1_power, |
254 | typename TTypes<T>::ConstScalar beta2_power, |
255 | typename TTypes<T>::ConstScalar lr, |
256 | typename TTypes<T>::ConstScalar beta1, |
257 | typename TTypes<T>::ConstScalar beta2, |
258 | typename TTypes<T>::ConstScalar epsilon, |
259 | typename TTypes<T>::ConstFlat grad); |
260 | }; |
261 | |
262 | template <typename Device, typename T> |
263 | struct ApplyAdaMax { |
264 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
265 | typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, |
266 | typename TTypes<T>::ConstScalar beta1_power, |
267 | typename TTypes<T>::ConstScalar lr, |
268 | typename TTypes<T>::ConstScalar beta1, |
269 | typename TTypes<T>::ConstScalar beta2, |
270 | typename TTypes<T>::ConstScalar epsilon, |
271 | typename TTypes<T>::ConstFlat grad); |
272 | }; |
273 | |
274 | template <typename Device, typename T> |
275 | struct ApplyRMSProp { |
276 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
277 | typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, |
278 | typename TTypes<T>::ConstScalar lr, |
279 | typename TTypes<T>::ConstScalar rho, |
280 | typename TTypes<T>::ConstScalar momentum, |
281 | typename TTypes<T>::ConstScalar epsilon, |
282 | typename TTypes<T>::ConstFlat grad); |
283 | }; |
284 | |
285 | template <typename Device, typename T> |
286 | struct ApplyCenteredRMSProp { |
287 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
288 | typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms, |
289 | typename TTypes<T>::Flat mom, |
290 | typename TTypes<T>::ConstScalar lr, |
291 | typename TTypes<T>::ConstScalar rho, |
292 | typename TTypes<T>::ConstScalar momentum, |
293 | typename TTypes<T>::ConstScalar epsilon, |
294 | typename TTypes<T>::ConstFlat grad); |
295 | }; |
296 | |
297 | template <typename Device, typename T> |
298 | struct ApplyAddSign { |
299 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
300 | typename TTypes<T>::Flat m, |
301 | typename TTypes<T>::ConstScalar lr, |
302 | typename TTypes<T>::ConstScalar alpha, |
303 | typename TTypes<T>::ConstScalar sign_decay, |
304 | typename TTypes<T>::ConstScalar beta, |
305 | typename TTypes<T>::ConstFlat grad); |
306 | }; |
307 | |
308 | template <typename Device, typename T> |
309 | struct ApplyPowerSign { |
310 | void operator()(const Device& d, typename TTypes<T>::Flat var, |
311 | typename TTypes<T>::Flat m, |
312 | typename TTypes<T>::ConstScalar lr, |
313 | typename TTypes<T>::ConstScalar logbase, |
314 | typename TTypes<T>::ConstScalar sign_decay, |
315 | typename TTypes<T>::ConstScalar beta, |
316 | typename TTypes<T>::ConstFlat grad); |
317 | }; |
318 | |
319 | } // end namespace functor |
320 | } // end namespace tensorflow |
321 | |
322 | #endif // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ |
323 | |