1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_KERNELS_LOSS_H_ |
16 | #define TENSORFLOW_CORE_KERNELS_LOSS_H_ |
17 | |
18 | #include "tensorflow/core/lib/core/status.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | class DualLossUpdater { |
23 | public: |
24 | virtual ~DualLossUpdater() {} |
25 | |
26 | // Compute update dual (alpha), based on a single example. Various strategies |
27 | // can be employed here, like newton step and/or line search or approximate |
28 | // step that decreases the dual sub-optimality. |
29 | virtual double ComputeUpdatedDual( |
30 | const int num_loss_partitions, const double label, |
31 | const double example_weight, const double current_dual, const double wx, |
32 | const double weighted_example_norm) const = 0; |
33 | |
34 | // Compute dual loss based on the current dual (alpha), example label (y) |
35 | // and example weight (cost). |
36 | virtual double ComputeDualLoss(const double current_dual, |
37 | const double example_label, |
38 | const double example_weight) const = 0; |
39 | |
40 | // Compute the primal loss based on current estimate of log-odds(wx), |
41 | // example label (y) and example weight (cost). |
42 | virtual double ComputePrimalLoss(const double wx, const double example_label, |
43 | const double example_weight) const = 0; |
44 | |
45 | // Primal loss derivative used to compute the dual residue in AdaSDCA |
46 | virtual double PrimalLossDerivative(const double wx, |
47 | const double example_label, |
48 | const double example_weight) const = 0; |
49 | |
50 | // This is gamma such that the loss derivative is 1/gamma Lipschitz |
51 | virtual double SmoothnessConstant() const = 0; |
52 | |
53 | // Converts binary example labels from 0.0 or 1.0 to appropriate range for |
54 | // each loss function. |
55 | virtual Status ConvertLabel(float* const example_label) const = 0; |
56 | }; |
57 | |
58 | } // namespace tensorflow |
59 | #endif // TENSORFLOW_CORE_KERNELS_LOSS_H_ |
60 | |