/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_

#include <string>
#include <vector>

#include "absl/base/casts.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

namespace tensorflow {
namespace tpu {

using OptimizationAlgorithm = OptimizationParameters::ParametersCase;

// Returns the name of the optimization algorithm.
string GetOptimizationAlgorithmName(OptimizationAlgorithm alg);

// Returns a user-friendly name for the optimization algorithm.
string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg);

// Returns all supported optimization algorithms.
std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms();

enum class GradientAccumulationSupport {
  // Accumulation cannot be used with this optimizer.
  kNotSupported,

  // Accumulation is allowed and changes optimizer behavior.
  kSupported,
};

// Returns the number of optimization parameter vectors used by the
// optimization algorithm, excluding the weights themselves and assuming no
// gradient accumulation.
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
                                      int *count);

// Returns whether (and how) an optimization algorithm supports gradient
// accumulation.
Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                      GradientAccumulationSupport *support);

// Returns whether the given optimization parameters have gradient accumulation
// turned on and the algorithm in use supports it (or ignores the setting).
// Returns an error if gradient accumulation is enabled but the algorithm does
// not support it.
Status UseGradientAccumulation(const OptimizationParameters &params,
                               bool *use_gradient_accumulation);
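
// Example (illustrative sketch, not part of this header's API): a caller might
// combine the two queries above roughly as follows, where `params` is an
// OptimizationParameters proto supplied by the caller.
//
//   GradientAccumulationSupport support;
//   TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
//   bool use_gradient_accumulation;
//   TF_RETURN_IF_ERROR(
//       UseGradientAccumulation(params, &use_gradient_accumulation));
//   if (use_gradient_accumulation) {
//     // Reserve an extra accumulator state variable per embedding table.
//   }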

// Returns the parameter specifications for the optimization algorithm (the
// main parameters first, followed by any auxiliary parameters such as Adagrad
// accumulators).
Status GetOptimizationAlgorithmStateVariables(
    const OptimizationParameters &params,
    std::vector<StateVariableSpecification> *state_variables);
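
// Example (hedged sketch): enumerating the state variables for a given
// configuration; `params` is an OptimizationParameters proto supplied by the
// caller, and the `name` field is assumed to be the one exposed by
// StateVariableSpecification in optimization_parameters.proto.
//
//   std::vector<StateVariableSpecification> state_variables;
//   TF_RETURN_IF_ERROR(
//       GetOptimizationAlgorithmStateVariables(params, &state_variables));
//   for (const StateVariableSpecification &spec : state_variables) {
//     // spec.name() identifies each table (e.g. parameters, accumulators) in
//     // the order expected by the load/retrieve ops.
//   }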

// Maximum value of auxiliary_parameter_count for any optimization algorithm.
// This count is used by the TPU embedding load/retrieve ops and needs to be
// independent of any particular TPU version, so we take the maximum across all
// TPU versions.
static constexpr int kMaxAuxiliaryParameterCount = 7;

// Fill value for gradient accumulators. This is a denormal so that it will be
// flushed to zero on the current TPU platforms and needs to continue to have
// the following properties in the future:
//
// 1. Does not have the same bit pattern as a zero and can be distinguished
//    from it using integer operations.
// 2. Treated as zero by floating-point arithmetic operations (at least
//    addition and subtraction).
// 3. Cannot be produced by any floating-point arithmetic operation, including
//    those involving itself.
//
// It does not need to compare equal or not equal to zero in floating point. We
// need to use a non-zero value here because some optimization algorithms are
// not no-ops on zero gradients, so we need to distinguish an accumulated
// gradient of zero from one that has been cleared after its gradients have
// already been applied to the parameters and accumulators.
inline float GradientAccumulatorInitialValue() {
  return absl::bit_cast<float, uint32>(1);
}
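
// Example (illustrative sketch): because the fill value is defined by its bit
// pattern rather than its floating-point value, property (1) above can be
// checked with an integer comparison, for instance:
//
//   bool IsGradientAccumulatorInitialValue(float x) {
//     return absl::bit_cast<uint32>(x) ==
//            absl::bit_cast<uint32>(GradientAccumulatorInitialValue());
//   }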

// Generic shape function for per-optimization-algorithm load ops.
class LoadOpShapeFunction {
 public:
  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;
};

// Generic shape function for per-optimization-algorithm retrieve ops.
class RetrieveOpShapeFunction {
 public:
  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;
};
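
// Example (hedged sketch of intended usage): the per-algorithm load/retrieve
// ops are expected to install these functors as their shape functions, along
// the lines of:
//
//   REGISTER_OP("LoadTPUEmbeddingAdagradParameters")
//       ...
//       .SetShapeFn(LoadOpShapeFunction());
//
// The exact op names, inputs, and attributes live in the op registration
// files, not in this header.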

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_