1
2/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15==============================================================================*/
16
17// Produced by generate_tpu_embedding_load_retrieve_ops.py (Google-internal).
18
19#include <string>
20
21#include "tensorflow/core/framework/op.h"
22#include "tensorflow/core/lib/core/status.h"
23#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
24#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
25
26namespace tensorflow {
27namespace tpu {
28
// Short name for OptimizationParameters::ParametersCase (presumably the
// protobuf-generated oneof-case enum selecting the optimizer kind).
// NOTE(review): not referenced by any registration in this file; presumably
// emitted by the generator for uniformity — candidate for removal if the
// generator output allows it.
using OptimizationAlgorithm = OptimizationParameters::ParametersCase;
30
// Stateful op LoadTPUEmbeddingAdagradParameters. Float32 inputs, in order:
// `parameters`, `accumulators` (presumably the table values and Adagrad's
// accumulator slot — TODO confirm). `num_shards`/`shard_id` have no default;
// `table_id` defaults to -1, `table_name`/`config` to "". Shape fn is
// LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
41
// Stateful op RetrieveTPUEmbeddingAdagradParameters. Float32 outputs, in
// order: `parameters`, `accumulators` — the same names/dtypes that
// LoadTPUEmbeddingAdagradParameters takes as inputs. `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is RetrieveOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingAdagradParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
52
// Stateful op LoadTPUEmbeddingAdagradMomentumParameters. Float32 inputs, in
// order: `parameters`, `accumulators`, `momenta` (presumably table values
// plus the accumulator and momentum slots — TODO confirm).
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is LoadOpShapeFunction(), defined
// outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingAdagradMomentumParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("momenta: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
64
// Stateful op RetrieveTPUEmbeddingAdagradMomentumParameters. Float32 outputs,
// in order: `parameters`, `accumulators`, `momenta` — the same names/dtypes
// that LoadTPUEmbeddingAdagradMomentumParameters takes as inputs.
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is RetrieveOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingAdagradMomentumParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("momenta: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
76
// Stateful op LoadTPUEmbeddingStochasticGradientDescentParameters. Single
// float32 input `parameters`; unlike the other optimizers in this file, no
// slot tensors are declared. `num_shards`/`shard_id` have no default;
// `table_id` defaults to -1, `table_name`/`config` to "". Shape fn is
// LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingStochasticGradientDescentParameters")
    .Input("parameters: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
86
// Stateful op RetrieveTPUEmbeddingStochasticGradientDescentParameters.
// Single float32 output `parameters`, matching the single input of
// LoadTPUEmbeddingStochasticGradientDescentParameters. `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is RetrieveOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingStochasticGradientDescentParameters")
    .Output("parameters: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
96
// Stateful op LoadTPUEmbeddingFTRLParameters. Float32 inputs, in order:
// `parameters`, `accumulators`, `linears` (presumably table values plus
// FTRL's accumulator and linear slots — TODO confirm). `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingFTRLParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("linears: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
108
// Stateful op RetrieveTPUEmbeddingFTRLParameters. Float32 outputs, in order:
// `parameters`, `accumulators`, `linears` — the same names/dtypes that
// LoadTPUEmbeddingFTRLParameters takes as inputs. `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is RetrieveOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingFTRLParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("linears: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
120
// Stateful op LoadTPUEmbeddingADAMParameters. Float32 inputs, in order:
// `parameters`, `momenta`, `velocities` (presumably table values plus Adam's
// first- and second-moment slots — TODO confirm). `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingADAMParameters")
    .Input("parameters: float32")
    .Input("momenta: float32")
    .Input("velocities: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
132
// Stateful op RetrieveTPUEmbeddingADAMParameters. Float32 outputs, in order:
// `parameters`, `momenta`, `velocities` — the same names/dtypes that
// LoadTPUEmbeddingADAMParameters takes as inputs. `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is RetrieveOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingADAMParameters")
    .Output("parameters: float32")
    .Output("momenta: float32")
    .Output("velocities: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
144
// Stateful op LoadTPUEmbeddingMomentumParameters. Float32 inputs, in order:
// `parameters`, `momenta` (presumably table values and the momentum slot —
// TODO confirm). `num_shards`/`shard_id` have no default; `table_id`
// defaults to -1, `table_name`/`config` to "". Shape fn is
// LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingMomentumParameters")
    .Input("parameters: float32")
    .Input("momenta: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
155
// Stateful op RetrieveTPUEmbeddingMomentumParameters. Float32 outputs, in
// order: `parameters`, `momenta` — the same names/dtypes that
// LoadTPUEmbeddingMomentumParameters takes as inputs. `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is RetrieveOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingMomentumParameters")
    .Output("parameters: float32")
    .Output("momenta: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
166
// Stateful op LoadTPUEmbeddingRMSPropParameters. Float32 inputs, in order:
// `parameters`, `ms`, `mom` (presumably table values plus the mean-square and
// momentum slots — TODO confirm). `num_shards`/`shard_id` have no default;
// `table_id` defaults to -1, `table_name`/`config` to "". Shape fn is
// LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingRMSPropParameters")
    .Input("parameters: float32")
    .Input("ms: float32")
    .Input("mom: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
178
// Stateful op RetrieveTPUEmbeddingRMSPropParameters. Float32 outputs, in
// order: `parameters`, `ms`, `mom` — the same names/dtypes that
// LoadTPUEmbeddingRMSPropParameters takes as inputs. `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is RetrieveOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingRMSPropParameters")
    .Output("parameters: float32")
    .Output("ms: float32")
    .Output("mom: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
190
// Stateful op LoadTPUEmbeddingCenteredRMSPropParameters. Float32 inputs, in
// order: `parameters`, `ms`, `mom`, `mg` (presumably table values plus the
// mean-square, momentum and mean-gradient slots — TODO confirm).
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is LoadOpShapeFunction(), defined
// outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingCenteredRMSPropParameters")
    .Input("parameters: float32")
    .Input("ms: float32")
    .Input("mom: float32")
    .Input("mg: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
203
// Stateful op RetrieveTPUEmbeddingCenteredRMSPropParameters. Float32 outputs,
// in order: `parameters`, `ms`, `mom`, `mg` — the same names/dtypes that
// LoadTPUEmbeddingCenteredRMSPropParameters takes as inputs.
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is RetrieveOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingCenteredRMSPropParameters")
    .Output("parameters: float32")
    .Output("ms: float32")
    .Output("mom: float32")
    .Output("mg: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
216
// Stateful op LoadTPUEmbeddingMDLAdagradLightParameters. Float32 inputs, in
// order: `parameters`, `accumulators`, `weights`, `benefits`. The meaning of
// the `weights`/`benefits` slots is not visible in this file — TODO document
// from the optimizer definition. `num_shards`/`shard_id` have no default;
// `table_id` defaults to -1, `table_name`/`config` to "". Shape fn is
// LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingMDLAdagradLightParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("weights: float32")
    .Input("benefits: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
229
// Stateful op RetrieveTPUEmbeddingMDLAdagradLightParameters. Float32 outputs,
// in order: `parameters`, `accumulators`, `weights`, `benefits` — the same
// names/dtypes that LoadTPUEmbeddingMDLAdagradLightParameters takes as
// inputs. `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is RetrieveOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingMDLAdagradLightParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("weights: float32")
    .Output("benefits: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
242
// Stateful op LoadTPUEmbeddingAdadeltaParameters. Float32 inputs, in order:
// `parameters`, `accumulators`, `updates` (presumably table values plus
// Adadelta's accumulated-gradient and accumulated-update slots — TODO
// confirm). `num_shards`/`shard_id` have no default; `table_id` defaults to
// -1, `table_name`/`config` to "". Shape fn is LoadOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingAdadeltaParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Input("updates: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
254
// Stateful op RetrieveTPUEmbeddingAdadeltaParameters. Float32 outputs, in
// order: `parameters`, `accumulators`, `updates` — the same names/dtypes
// that LoadTPUEmbeddingAdadeltaParameters takes as inputs.
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is RetrieveOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingAdadeltaParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Output("updates: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
266
// Stateful op LoadTPUEmbeddingProximalAdagradParameters. Float32 inputs, in
// order: `parameters`, `accumulators` (presumably table values and the
// Adagrad-style accumulator slot — TODO confirm). `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingProximalAdagradParameters")
    .Input("parameters: float32")
    .Input("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
277
// Stateful op RetrieveTPUEmbeddingProximalAdagradParameters. Float32 outputs,
// in order: `parameters`, `accumulators` — the same names/dtypes that
// LoadTPUEmbeddingProximalAdagradParameters takes as inputs.
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is RetrieveOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingProximalAdagradParameters")
    .Output("parameters: float32")
    .Output("accumulators: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
288
// Stateful op LoadTPUEmbeddingProximalYogiParameters. Float32 inputs, in
// order: `parameters`, `v`, `m` — note `v` is declared before `m`
// (presumably the second- and first-moment slots — TODO confirm).
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is LoadOpShapeFunction(), defined
// outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingProximalYogiParameters")
    .Input("parameters: float32")
    .Input("v: float32")
    .Input("m: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
300
// Stateful op RetrieveTPUEmbeddingProximalYogiParameters. Float32 outputs, in
// order: `parameters`, `v`, `m` — the same names/dtypes (and the same
// v-before-m order) that LoadTPUEmbeddingProximalYogiParameters takes as
// inputs. `num_shards`/`shard_id` have no default; `table_id` defaults to
// -1, `table_name`/`config` to "". Shape fn is RetrieveOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingProximalYogiParameters")
    .Output("parameters: float32")
    .Output("v: float32")
    .Output("m: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
312
// Stateful op LoadTPUEmbeddingFrequencyEstimatorParameters. Float32 inputs,
// in order: `parameters`, `last_hit_step`. `last_hit_step` is declared
// float32 here despite its step-count-like name. `num_shards`/`shard_id`
// have no default; `table_id` defaults to -1, `table_name`/`config` to "".
// Shape fn is LoadOpShapeFunction(), defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("LoadTPUEmbeddingFrequencyEstimatorParameters")
    .Input("parameters: float32")
    .Input("last_hit_step: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(LoadOpShapeFunction());
323
// Stateful op RetrieveTPUEmbeddingFrequencyEstimatorParameters. Float32
// outputs, in order: `parameters`, `last_hit_step` — the same names/dtypes
// that LoadTPUEmbeddingFrequencyEstimatorParameters takes as inputs.
// `num_shards`/`shard_id` have no default; `table_id` defaults to -1,
// `table_name`/`config` to "". Shape fn is RetrieveOpShapeFunction(),
// defined outside this file.
// NOTE(review): declaration order presumably fixes the OpDef signature; do
// not reorder.
REGISTER_OP("RetrieveTPUEmbeddingFrequencyEstimatorParameters")
    .Output("parameters: float32")
    .Output("last_hit_step: float32")
    .Attr("table_id: int = -1")
    .Attr("table_name: string = \"\"")
    .Attr("num_shards: int")
    .Attr("shard_id: int")
    .Attr("config: string = \"\"")
    .SetIsStateful()
    .SetShapeFn(RetrieveOpShapeFunction());
334
335} // namespace tpu
336} // namespace tensorflow
337