/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace tensorflow {
namespace tpu {

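// Returns the short identifier for the optimization algorithm (for example,
// "Adagrad"); this is the name used in the error messages below.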
std::string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kAdagradMomentum:
      return "AdagradMomentum";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "BoundedAdagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "StochasticGradientDescent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "CenteredRMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDLAdagradLight";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "ProximalAdagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "OnlineYogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "ProximalYogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "FrequencyEstimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::kAssign:
      return "Assign";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "*** Not set ***";
  }
  return "*** Not set ***";
}

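// Returns a human-readable name for the optimization algorithm, suitable for
// user-facing messages.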
std::string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kAdagradMomentum:
      return "Adagrad with Momentum";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "Bounded Adagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "stochastic gradient descent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "centered RMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDL Adagrad Light";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "proximal Adagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "online Yogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "proximal Yogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "frequency estimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::kAssign:
      return "Assign";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "unknown (not specified)";
  }
  return "unknown (not specified)";
}

// Returns the number of optimization parameter vectors used by the
// optimization algorithm, excluding the weights themselves and assuming no
// gradient accumulation.
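// For example, Adam keeps two auxiliary vectors (momenta and velocities),
// while plain stochastic gradient descent keeps none.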
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
                                      int* count) {
  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kAdagradMomentum:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kBoundedAdagrad:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kStochasticGradientDescent:
      *count = 0;
      return OkStatus();
    case OptimizationAlgorithm::kFtrl:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kAdam:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kMomentum:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kRmsProp:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kCenteredRmsProp:
      *count = 3;
      return OkStatus();
    case OptimizationAlgorithm::kMdlAdagradLight:
      *count = 3;
      return OkStatus();
    case OptimizationAlgorithm::kAdadelta:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kProximalAdagrad:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kOnlineYogi:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kProximalYogi:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kFrequencyEstimator:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kUserDefinedProgram: {
      const xla::ProgramShapeProto& program_shape =
          params.user_defined_program().program().host_program_shape();

      const int num_inputs = program_shape.parameters_size();
      const int num_outputs = program_shape.result().tuple_shapes_size();

      if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
                               (num_inputs != num_outputs + 2))) {
        return errors::InvalidArgument(
            "User-defined TPU embedding optimizer program must have at least "
            "two inputs and the number of outputs must be 1 or 2 less than "
            "the number of inputs. Received ",
            num_inputs, " input(s) and ", num_outputs, " output(s).");
      }

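      // The outputs are the updated embedding table followed by one tensor
      // per auxiliary slot, so the slot count is the number of outputs minus
      // one.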
      *count = num_outputs - 1;

      return OkStatus();
    }
    case OptimizationAlgorithm::kAssign:
      *count = 0;
      return OkStatus();
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return errors::InvalidArgument("No optimization algorithm specified");
  }
  return errors::InvalidArgument("No optimization algorithm specified");
}

Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                      GradientAccumulationSupport* support) {
  int auxiliary_parameter_count;
  TF_RETURN_IF_ERROR(
      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
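  // Gradient accumulation adds one more per-table state variable, so the
  // algorithm supports it only if that extra variable still fits under
  // kMaxAuxiliaryParameterCount.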
  *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                 ? GradientAccumulationSupport::kSupported
                 : GradientAccumulationSupport::kNotSupported;
  return OkStatus();
}

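// Determines whether gradient accumulation should actually be used, combining
// the user's requested setting with whether the algorithm supports it.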
Status UseGradientAccumulation(const OptimizationParameters& params,
                               bool* use_gradient_accumulation) {
  GradientAccumulationSupport support;
  TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
  bool raw_gradient_accumulation_status = false;
  switch (params.gradient_accumulation_status()) {
    case GradientAccumulationStatus::UNSPECIFIED: {
      // Gradient accumulation is turned on by default when the status is
      // unspecified.
      raw_gradient_accumulation_status = true;
      break;
    }
    case GradientAccumulationStatus::DISABLED: {
      raw_gradient_accumulation_status = false;
      break;
    }
    case GradientAccumulationStatus::ENABLED: {
      raw_gradient_accumulation_status = true;
      break;
    }
    default:
      return errors::Internal(
          absl::StrCat("Unsupported gradient accumulation status ",
                       GradientAccumulationStatus_Status_Name(
                           params.gradient_accumulation_status())));
  }
  switch (support) {
    case GradientAccumulationSupport::kSupported: {
      *use_gradient_accumulation = raw_gradient_accumulation_status;
      break;
    }
    case GradientAccumulationSupport::kNotSupported: {
      if (raw_gradient_accumulation_status) {
        return errors::InvalidArgument(strings::Printf(
            "Optimization algorithm %s does not support gradient accumulation "
            "but parameters specify it.",
            GetOptimizationAlgorithmName(params.parameters_case()).c_str()));
      }
      *use_gradient_accumulation = false;
      break;
    }
  }
  return OkStatus();
}

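// Lists the state variables for the given algorithm: the embedding table
// itself ("parameters"), the algorithm's auxiliary slots, and, if enabled,
// the gradient accumulator, in a fixed order (see the compatibility note
// below).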
Status GetOptimizationAlgorithmStateVariables(
    const OptimizationParameters& params,
    std::vector<StateVariableSpecification>* state_variables) {
  // The parameter set for the weights themselves is required to be named
  // "parameters". The rest should stay stable for compatibility. There is an
  // internal function, GetOptimizationAlgorithmStateVariableInternalIndices,
  // that needs to be updated along with this one.
  bool use_gradient_accumulation;
  TF_RETURN_IF_ERROR(
      UseGradientAccumulation(params, &use_gradient_accumulation));

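  // Appends a state variable slot whose contents are supplied by the user
  // (the user_defined initialization mode of StateVariableSpecification).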
  auto add_state_variable = [&](const std::string& name) {
    StateVariableSpecification spec;
    spec.set_name(name);
    (void)spec.mutable_user_defined();
    state_variables->push_back(spec);
  };

  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      break;
    }
    case OptimizationAlgorithm::kAdagradMomentum: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("momenta");
      break;
    }
    case OptimizationAlgorithm::kBoundedAdagrad: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      break;
    }
    case OptimizationAlgorithm::kStochasticGradientDescent: {
      add_state_variable("parameters");
      break;
    }
    case OptimizationAlgorithm::kFtrl: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("linears");
      break;
    }
    case OptimizationAlgorithm::kAdam: {
      add_state_variable("parameters");
      add_state_variable("momenta");
      add_state_variable("velocities");
      break;
    }
    case OptimizationAlgorithm::kMomentum: {
      add_state_variable("parameters");
      add_state_variable("momenta");
      break;
    }
    case OptimizationAlgorithm::kRmsProp: {
      add_state_variable("parameters");
      add_state_variable("ms");
      add_state_variable("mom");
      break;
    }
    case OptimizationAlgorithm::kCenteredRmsProp: {
      add_state_variable("parameters");
      add_state_variable("ms");
      add_state_variable("mom");
      add_state_variable("mg");
      break;
    }
    case OptimizationAlgorithm::kMdlAdagradLight: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("weights");
      add_state_variable("benefits");
      break;
    }
    case OptimizationAlgorithm::kAdadelta: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("updates");
      break;
    }
    case OptimizationAlgorithm::kProximalAdagrad: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      break;
    }
    case OptimizationAlgorithm::kOnlineYogi: {
      add_state_variable("parameters");
      add_state_variable("vs");
      add_state_variable("linears");
      break;
    }
    case OptimizationAlgorithm::kProximalYogi: {
      add_state_variable("parameters");
      add_state_variable("v");
      add_state_variable("m");
      break;
    }
    case OptimizationAlgorithm::kFrequencyEstimator: {
      add_state_variable("parameters");
      add_state_variable("last_hit_step");
      break;
    }
    case OptimizationAlgorithm::kUserDefinedProgram: {
      add_state_variable("parameters");
      int num_slots = -1;
      TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
      for (int i = 0; i < num_slots; ++i) {
        add_state_variable(absl::StrCat("Slot_", i));
      }
      break;
    }
    case OptimizationAlgorithm::kAssign: {
      add_state_variable("parameters");
      break;
    }
    case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
      return errors::InvalidArgument("No optimization algorithm specified");
    }
  }

  // This needs to be last for compatibility.
  if (use_gradient_accumulation) {
    StateVariableSpecification gradient_acc;
    gradient_acc.set_name("gradient_accumulators");
    gradient_acc.mutable_fill_with_constant()->set_initial_value(
        GradientAccumulatorInitialValue());
    state_variables->push_back(std::move(gradient_acc));
  }

  if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
    return errors::InvalidArgument(
        "Optimization algorithm ",
        GetOptimizationAlgorithmName(params.parameters_case()),
        " does not support gradient accumulation because it "
        "already has too many other accumulators");
  }
  return OkStatus();
}

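// Enumerates every optimization algorithm handled above.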
std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
  return {
      OptimizationAlgorithm::kAdagrad,
      OptimizationAlgorithm::kAdagradMomentum,
      OptimizationAlgorithm::kBoundedAdagrad,
      OptimizationAlgorithm::kStochasticGradientDescent,
      OptimizationAlgorithm::kFtrl,
      OptimizationAlgorithm::kAdam,
      OptimizationAlgorithm::kMomentum,
      OptimizationAlgorithm::kRmsProp,
      OptimizationAlgorithm::kCenteredRmsProp,
      OptimizationAlgorithm::kMdlAdagradLight,
      OptimizationAlgorithm::kAdadelta,
      OptimizationAlgorithm::kProximalAdagrad,
      OptimizationAlgorithm::kOnlineYogi,
      OptimizationAlgorithm::kProximalYogi,
      OptimizationAlgorithm::kFrequencyEstimator,
      OptimizationAlgorithm::kUserDefinedProgram,
      OptimizationAlgorithm::kAssign,
  };
}

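// Shape function for the ops that load embedding optimization parameters onto
// the TPU: every input (the embedding table and each auxiliary slot) must be
// a rank-2 matrix, and all of them must have compatible shapes.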
Status LoadOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  std::string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one of table_id and table_name must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));

  // Verify that the input shapes have rank 2 and are compatible with each
  // other when they are required to be valid.
  shape_inference::ShapeHandle parameter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
  for (int j = 1; j < c->num_inputs(); ++j) {
    shape_inference::ShapeHandle accumulator_j_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
    shape_inference::ShapeHandle merged;
    TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
  }

  return OkStatus();
}

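// Shape function for the ops that retrieve embedding optimization parameters
// from the TPU; each output is a rank-2 matrix whose dimensions are unknown
// at graph-construction time.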
Status RetrieveOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  std::string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one of table_id and table_name must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
  for (int j = 0; j < c->num_outputs(); ++j) {
    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
                         2, c->UnknownDim())));
  }
  return OkStatus();
}

}  // namespace tpu
}  // namespace tensorflow