1 | |
---|---|
2 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
3 | |
4 | Licensed under the Apache License, Version 2.0 (the "License"); |
5 | you may not use this file except in compliance with the License. |
6 | You may obtain a copy of the License at |
7 | |
8 | http://www.apache.org/licenses/LICENSE-2.0 |
9 | |
10 | Unless required by applicable law or agreed to in writing, software |
11 | distributed under the License is distributed on an "AS IS" BASIS, |
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | See the License for the specific language governing permissions and |
14 | limitations under the License. |
15 | ==============================================================================*/ |
16 | |
17 | // Produced by generate_tpu_embedding_load_retrieve_ops.py (Google-internal). |
18 | |
19 | #include <string> |
20 | |
21 | #include "tensorflow/core/framework/op.h" |
22 | #include "tensorflow/core/lib/core/status.h" |
23 | #include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h" |
24 | #include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace tpu { |
28 | |
// Shorthand for the ParametersCase type nested in OptimizationParameters
// (presumably the protobuf-generated oneof-case enum from
// optimization_parameters.pb.h; TODO confirm). It is not referenced by any
// registration visible in this file.
29 | using OptimizationAlgorithm = OptimizationParameters::ParametersCase; |
30 | |
// Registers the LoadTPUEmbeddingAdagradParameters op.
//   Inputs (float32, in order): parameters, accumulators.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
31 | REGISTER_OP("LoadTPUEmbeddingAdagradParameters") |
32 | .Input("parameters: float32") |
33 | .Input("accumulators: float32") |
34 | .Attr("table_id: int = -1") |
35 | .Attr("table_name: string = \"\"") |
36 | .Attr("num_shards: int") |
37 | .Attr("shard_id: int") |
38 | .Attr("config: string = \"\"") |
39 | .SetIsStateful() |
40 | .SetShapeFn(LoadOpShapeFunction()); |
41 | |
// Registers the RetrieveTPUEmbeddingAdagradParameters op.
//   Outputs (float32, in order): parameters, accumulators. These are the same
//   names and order as the inputs of LoadTPUEmbeddingAdagradParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
42 | REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters") |
43 | .Output("parameters: float32") |
44 | .Output("accumulators: float32") |
45 | .Attr("table_id: int = -1") |
46 | .Attr("table_name: string = \"\"") |
47 | .Attr("num_shards: int") |
48 | .Attr("shard_id: int") |
49 | .Attr("config: string = \"\"") |
50 | .SetIsStateful() |
51 | .SetShapeFn(RetrieveOpShapeFunction()); |
52 | |
// Registers the LoadTPUEmbeddingAdagradMomentumParameters op.
//   Inputs (float32, in order): parameters, accumulators, momenta.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
53 | REGISTER_OP("LoadTPUEmbeddingAdagradMomentumParameters") |
54 | .Input("parameters: float32") |
55 | .Input("accumulators: float32") |
56 | .Input("momenta: float32") |
57 | .Attr("table_id: int = -1") |
58 | .Attr("table_name: string = \"\"") |
59 | .Attr("num_shards: int") |
60 | .Attr("shard_id: int") |
61 | .Attr("config: string = \"\"") |
62 | .SetIsStateful() |
63 | .SetShapeFn(LoadOpShapeFunction()); |
64 | |
// Registers the RetrieveTPUEmbeddingAdagradMomentumParameters op.
//   Outputs (float32, in order): parameters, accumulators, momenta. These are
//   the same names and order as the inputs of
//   LoadTPUEmbeddingAdagradMomentumParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
65 | REGISTER_OP("RetrieveTPUEmbeddingAdagradMomentumParameters") |
66 | .Output("parameters: float32") |
67 | .Output("accumulators: float32") |
68 | .Output("momenta: float32") |
69 | .Attr("table_id: int = -1") |
70 | .Attr("table_name: string = \"\"") |
71 | .Attr("num_shards: int") |
72 | .Attr("shard_id: int") |
73 | .Attr("config: string = \"\"") |
74 | .SetIsStateful() |
75 | .SetShapeFn(RetrieveOpShapeFunction()); |
76 | |
// Registers the LoadTPUEmbeddingStochasticGradientDescentParameters op.
//   Inputs (float32): parameters only; no auxiliary tensors are declared.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
77 | REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters") |
78 | .Input("parameters: float32") |
79 | .Attr("table_id: int = -1") |
80 | .Attr("table_name: string = \"\"") |
81 | .Attr("num_shards: int") |
82 | .Attr("shard_id: int") |
83 | .Attr("config: string = \"\"") |
84 | .SetIsStateful() |
85 | .SetShapeFn(LoadOpShapeFunction()); |
86 | |
// Registers the RetrieveTPUEmbeddingStochasticGradientDescentParameters op.
//   Outputs (float32): parameters only, matching the single input of
//   LoadTPUEmbeddingStochasticGradientDescentParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
87 | REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters") |
88 | .Output("parameters: float32") |
89 | .Attr("table_id: int = -1") |
90 | .Attr("table_name: string = \"\"") |
91 | .Attr("num_shards: int") |
92 | .Attr("shard_id: int") |
93 | .Attr("config: string = \"\"") |
94 | .SetIsStateful() |
95 | .SetShapeFn(RetrieveOpShapeFunction()); |
96 | |
// Registers the LoadTPUEmbeddingFTRLParameters op.
//   Inputs (float32, in order): parameters, accumulators, linears.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
97 | REGISTER_OP("LoadTPUEmbeddingFTRLParameters") |
98 | .Input("parameters: float32") |
99 | .Input("accumulators: float32") |
100 | .Input("linears: float32") |
101 | .Attr("table_id: int = -1") |
102 | .Attr("table_name: string = \"\"") |
103 | .Attr("num_shards: int") |
104 | .Attr("shard_id: int") |
105 | .Attr("config: string = \"\"") |
106 | .SetIsStateful() |
107 | .SetShapeFn(LoadOpShapeFunction()); |
108 | |
// Registers the RetrieveTPUEmbeddingFTRLParameters op.
//   Outputs (float32, in order): parameters, accumulators, linears. These are
//   the same names and order as the inputs of LoadTPUEmbeddingFTRLParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
109 | REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters") |
110 | .Output("parameters: float32") |
111 | .Output("accumulators: float32") |
112 | .Output("linears: float32") |
113 | .Attr("table_id: int = -1") |
114 | .Attr("table_name: string = \"\"") |
115 | .Attr("num_shards: int") |
116 | .Attr("shard_id: int") |
117 | .Attr("config: string = \"\"") |
118 | .SetIsStateful() |
119 | .SetShapeFn(RetrieveOpShapeFunction()); |
120 | |
// Registers the LoadTPUEmbeddingADAMParameters op.
//   Inputs (float32, in order): parameters, momenta, velocities.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
121 | REGISTER_OP("LoadTPUEmbeddingADAMParameters") |
122 | .Input("parameters: float32") |
123 | .Input("momenta: float32") |
124 | .Input("velocities: float32") |
125 | .Attr("table_id: int = -1") |
126 | .Attr("table_name: string = \"\"") |
127 | .Attr("num_shards: int") |
128 | .Attr("shard_id: int") |
129 | .Attr("config: string = \"\"") |
130 | .SetIsStateful() |
131 | .SetShapeFn(LoadOpShapeFunction()); |
132 | |
// Registers the RetrieveTPUEmbeddingADAMParameters op.
//   Outputs (float32, in order): parameters, momenta, velocities. These are
//   the same names and order as the inputs of LoadTPUEmbeddingADAMParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
133 | REGISTER_OP("RetrieveTPUEmbeddingADAMParameters") |
134 | .Output("parameters: float32") |
135 | .Output("momenta: float32") |
136 | .Output("velocities: float32") |
137 | .Attr("table_id: int = -1") |
138 | .Attr("table_name: string = \"\"") |
139 | .Attr("num_shards: int") |
140 | .Attr("shard_id: int") |
141 | .Attr("config: string = \"\"") |
142 | .SetIsStateful() |
143 | .SetShapeFn(RetrieveOpShapeFunction()); |
144 | |
// Registers the LoadTPUEmbeddingMomentumParameters op.
//   Inputs (float32, in order): parameters, momenta.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
145 | REGISTER_OP("LoadTPUEmbeddingMomentumParameters") |
146 | .Input("parameters: float32") |
147 | .Input("momenta: float32") |
148 | .Attr("table_id: int = -1") |
149 | .Attr("table_name: string = \"\"") |
150 | .Attr("num_shards: int") |
151 | .Attr("shard_id: int") |
152 | .Attr("config: string = \"\"") |
153 | .SetIsStateful() |
154 | .SetShapeFn(LoadOpShapeFunction()); |
155 | |
// Registers the RetrieveTPUEmbeddingMomentumParameters op.
//   Outputs (float32, in order): parameters, momenta. These are the same
//   names and order as the inputs of LoadTPUEmbeddingMomentumParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
156 | REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters") |
157 | .Output("parameters: float32") |
158 | .Output("momenta: float32") |
159 | .Attr("table_id: int = -1") |
160 | .Attr("table_name: string = \"\"") |
161 | .Attr("num_shards: int") |
162 | .Attr("shard_id: int") |
163 | .Attr("config: string = \"\"") |
164 | .SetIsStateful() |
165 | .SetShapeFn(RetrieveOpShapeFunction()); |
166 | |
// Registers the LoadTPUEmbeddingRMSPropParameters op.
//   Inputs (float32, in order): parameters, ms, mom.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): ms / mom are presumably the mean-square and momentum slots;
// the -1 / empty defaults look like not-set markers. Confirm both.
167 | REGISTER_OP("LoadTPUEmbeddingRMSPropParameters") |
168 | .Input("parameters: float32") |
169 | .Input("ms: float32") |
170 | .Input("mom: float32") |
171 | .Attr("table_id: int = -1") |
172 | .Attr("table_name: string = \"\"") |
173 | .Attr("num_shards: int") |
174 | .Attr("shard_id: int") |
175 | .Attr("config: string = \"\"") |
176 | .SetIsStateful() |
177 | .SetShapeFn(LoadOpShapeFunction()); |
178 | |
// Registers the RetrieveTPUEmbeddingRMSPropParameters op.
//   Outputs (float32, in order): parameters, ms, mom. These are the same
//   names and order as the inputs of LoadTPUEmbeddingRMSPropParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): ms / mom are presumably the mean-square and momentum slots;
// the -1 / empty defaults look like not-set markers. Confirm both.
179 | REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters") |
180 | .Output("parameters: float32") |
181 | .Output("ms: float32") |
182 | .Output("mom: float32") |
183 | .Attr("table_id: int = -1") |
184 | .Attr("table_name: string = \"\"") |
185 | .Attr("num_shards: int") |
186 | .Attr("shard_id: int") |
187 | .Attr("config: string = \"\"") |
188 | .SetIsStateful() |
189 | .SetShapeFn(RetrieveOpShapeFunction()); |
190 | |
// Registers the LoadTPUEmbeddingCenteredRMSPropParameters op.
//   Inputs (float32, in order): parameters, ms, mom, mg. That is one more
//   tensor (mg) than LoadTPUEmbeddingRMSPropParameters declares.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): mg is presumably the mean-gradient slot; the -1 / empty
// defaults look like not-set markers. Confirm both.
191 | REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters") |
192 | .Input("parameters: float32") |
193 | .Input("ms: float32") |
194 | .Input("mom: float32") |
195 | .Input("mg: float32") |
196 | .Attr("table_id: int = -1") |
197 | .Attr("table_name: string = \"\"") |
198 | .Attr("num_shards: int") |
199 | .Attr("shard_id: int") |
200 | .Attr("config: string = \"\"") |
201 | .SetIsStateful() |
202 | .SetShapeFn(LoadOpShapeFunction()); |
203 | |
// Registers the RetrieveTPUEmbeddingCenteredRMSPropParameters op.
//   Outputs (float32, in order): parameters, ms, mom, mg. These are the same
//   names and order as the inputs of
//   LoadTPUEmbeddingCenteredRMSPropParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): mg is presumably the mean-gradient slot; the -1 / empty
// defaults look like not-set markers. Confirm both.
204 | REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters") |
205 | .Output("parameters: float32") |
206 | .Output("ms: float32") |
207 | .Output("mom: float32") |
208 | .Output("mg: float32") |
209 | .Attr("table_id: int = -1") |
210 | .Attr("table_name: string = \"\"") |
211 | .Attr("num_shards: int") |
212 | .Attr("shard_id: int") |
213 | .Attr("config: string = \"\"") |
214 | .SetIsStateful() |
215 | .SetShapeFn(RetrieveOpShapeFunction()); |
216 | |
// Registers the LoadTPUEmbeddingMDLAdagradLightParameters op.
//   Inputs (float32, in order): parameters, accumulators, weights, benefits.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
217 | REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters") |
218 | .Input("parameters: float32") |
219 | .Input("accumulators: float32") |
220 | .Input("weights: float32") |
221 | .Input("benefits: float32") |
222 | .Attr("table_id: int = -1") |
223 | .Attr("table_name: string = \"\"") |
224 | .Attr("num_shards: int") |
225 | .Attr("shard_id: int") |
226 | .Attr("config: string = \"\"") |
227 | .SetIsStateful() |
228 | .SetShapeFn(LoadOpShapeFunction()); |
229 | |
// Registers the RetrieveTPUEmbeddingMDLAdagradLightParameters op.
//   Outputs (float32, in order): parameters, accumulators, weights, benefits.
//   These are the same names and order as the inputs of
//   LoadTPUEmbeddingMDLAdagradLightParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
230 | REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters") |
231 | .Output("parameters: float32") |
232 | .Output("accumulators: float32") |
233 | .Output("weights: float32") |
234 | .Output("benefits: float32") |
235 | .Attr("table_id: int = -1") |
236 | .Attr("table_name: string = \"\"") |
237 | .Attr("num_shards: int") |
238 | .Attr("shard_id: int") |
239 | .Attr("config: string = \"\"") |
240 | .SetIsStateful() |
241 | .SetShapeFn(RetrieveOpShapeFunction()); |
242 | |
// Registers the LoadTPUEmbeddingAdadeltaParameters op.
//   Inputs (float32, in order): parameters, accumulators, updates.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
243 | REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters") |
244 | .Input("parameters: float32") |
245 | .Input("accumulators: float32") |
246 | .Input("updates: float32") |
247 | .Attr("table_id: int = -1") |
248 | .Attr("table_name: string = \"\"") |
249 | .Attr("num_shards: int") |
250 | .Attr("shard_id: int") |
251 | .Attr("config: string = \"\"") |
252 | .SetIsStateful() |
253 | .SetShapeFn(LoadOpShapeFunction()); |
254 | |
// Registers the RetrieveTPUEmbeddingAdadeltaParameters op.
//   Outputs (float32, in order): parameters, accumulators, updates. These are
//   the same names and order as the inputs of
//   LoadTPUEmbeddingAdadeltaParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
255 | REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters") |
256 | .Output("parameters: float32") |
257 | .Output("accumulators: float32") |
258 | .Output("updates: float32") |
259 | .Attr("table_id: int = -1") |
260 | .Attr("table_name: string = \"\"") |
261 | .Attr("num_shards: int") |
262 | .Attr("shard_id: int") |
263 | .Attr("config: string = \"\"") |
264 | .SetIsStateful() |
265 | .SetShapeFn(RetrieveOpShapeFunction()); |
266 | |
// Registers the LoadTPUEmbeddingProximalAdagradParameters op.
//   Inputs (float32, in order): parameters, accumulators. This is the same
//   input list as LoadTPUEmbeddingAdagradParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
267 | REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters") |
268 | .Input("parameters: float32") |
269 | .Input("accumulators: float32") |
270 | .Attr("table_id: int = -1") |
271 | .Attr("table_name: string = \"\"") |
272 | .Attr("num_shards: int") |
273 | .Attr("shard_id: int") |
274 | .Attr("config: string = \"\"") |
275 | .SetIsStateful() |
276 | .SetShapeFn(LoadOpShapeFunction()); |
277 | |
// Registers the RetrieveTPUEmbeddingProximalAdagradParameters op.
//   Outputs (float32, in order): parameters, accumulators. These are the same
//   names and order as the inputs of
//   LoadTPUEmbeddingProximalAdagradParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
278 | REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters") |
279 | .Output("parameters: float32") |
280 | .Output("accumulators: float32") |
281 | .Attr("table_id: int = -1") |
282 | .Attr("table_name: string = \"\"") |
283 | .Attr("num_shards: int") |
284 | .Attr("shard_id: int") |
285 | .Attr("config: string = \"\"") |
286 | .SetIsStateful() |
287 | .SetShapeFn(RetrieveOpShapeFunction()); |
288 | |
// Registers the LoadTPUEmbeddingProximalYogiParameters op.
//   Inputs (float32, in order): parameters, v, m. Note that v is declared
//   before m; this order is part of the op signature.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
289 | REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters") |
290 | .Input("parameters: float32") |
291 | .Input("v: float32") |
292 | .Input("m: float32") |
293 | .Attr("table_id: int = -1") |
294 | .Attr("table_name: string = \"\"") |
295 | .Attr("num_shards: int") |
296 | .Attr("shard_id: int") |
297 | .Attr("config: string = \"\"") |
298 | .SetIsStateful() |
299 | .SetShapeFn(LoadOpShapeFunction()); |
300 | |
// Registers the RetrieveTPUEmbeddingProximalYogiParameters op.
//   Outputs (float32, in order): parameters, v, m. These are the same names
//   and order as the inputs of LoadTPUEmbeddingProximalYogiParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): the -1 / empty defaults look like not-set markers; confirm
// against the kernel.
301 | REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters") |
302 | .Output("parameters: float32") |
303 | .Output("v: float32") |
304 | .Output("m: float32") |
305 | .Attr("table_id: int = -1") |
306 | .Attr("table_name: string = \"\"") |
307 | .Attr("num_shards: int") |
308 | .Attr("shard_id: int") |
309 | .Attr("config: string = \"\"") |
310 | .SetIsStateful() |
311 | .SetShapeFn(RetrieveOpShapeFunction()); |
312 | |
// Registers the LoadTPUEmbeddingFrequencyEstimatorParameters op.
//   Inputs (float32, in order): parameters, last_hit_step.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No outputs are declared. The op is marked stateful and takes its shape
// function from LoadOpShapeFunction().
// NOTE(review): last_hit_step is declared float32 even though the name
// suggests a step counter; confirm this is intended. The -1 / empty defaults
// look like not-set markers; confirm against the kernel.
313 | REGISTER_OP("LoadTPUEmbeddingFrequencyEstimatorParameters") |
314 | .Input("parameters: float32") |
315 | .Input("last_hit_step: float32") |
316 | .Attr("table_id: int = -1") |
317 | .Attr("table_name: string = \"\"") |
318 | .Attr("num_shards: int") |
319 | .Attr("shard_id: int") |
320 | .Attr("config: string = \"\"") |
321 | .SetIsStateful() |
322 | .SetShapeFn(LoadOpShapeFunction()); |
323 | |
// Registers the RetrieveTPUEmbeddingFrequencyEstimatorParameters op.
//   Outputs (float32, in order): parameters, last_hit_step. These are the
//   same names and order as the inputs of
//   LoadTPUEmbeddingFrequencyEstimatorParameters.
//   Attrs: table_id (int, default -1), table_name (string, default empty),
//          num_shards (int, required), shard_id (int, required),
//          config (string, default empty).
// No inputs are declared. The op is marked stateful and takes its shape
// function from RetrieveOpShapeFunction().
// NOTE(review): last_hit_step is declared float32 even though the name
// suggests a step counter; confirm this is intended. The -1 / empty defaults
// look like not-set markers; confirm against the kernel.
324 | REGISTER_OP("RetrieveTPUEmbeddingFrequencyEstimatorParameters") |
325 | .Output("parameters: float32") |
326 | .Output("last_hit_step: float32") |
327 | .Attr("table_id: int = -1") |
328 | .Attr("table_name: string = \"\"") |
329 | .Attr("num_shards: int") |
330 | .Attr("shard_id: int") |
331 | .Attr("config: string = \"\"") |
332 | .SetIsStateful() |
333 | .SetShapeFn(RetrieveOpShapeFunction()); |
334 | |
335 | } // namespace tpu |
336 | } // namespace tensorflow |
337 |