1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_ |
18 | |
19 | #include <limits> |
20 | |
21 | #include "tensorflow/core/kernels/loss.h" |
22 | #include "tensorflow/core/lib/core/errors.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | class SmoothHingeLossUpdater : public DualLossUpdater { |
28 | public: |
29 | // Computes the updated dual variable (corresponding) to a single example. The |
30 | // updated dual value maximizes the objective function of the dual |
31 | // optimization problem associated with smooth hinge loss. The computations |
32 | // are detailed in readme.md. |
33 | double ComputeUpdatedDual(const int num_partitions, const double label, |
34 | const double example_weight, |
35 | const double current_dual, const double wx, |
36 | const double weighted_example_norm) const final { |
37 | // Intuitively there are 3 cases: |
38 | // a. new optimal value of the dual variable falls within the admissible |
39 | // range [0, 1]. In this case we set new dual to this value. |
40 | // b. new optimal value is < 0. Then, because of convexity, the optimal |
41 | // valid value for new dual = 0 |
42 | // c. new optimal value > 1.0. Then new optimal value should be set to 1.0. |
43 | const double candidate_optimal_dual = |
44 | current_dual + |
45 | (label - wx - gamma * current_dual) / |
46 | (num_partitions * example_weight * weighted_example_norm + gamma); |
47 | if (label * candidate_optimal_dual < 0) { |
48 | return 0.0; |
49 | } |
50 | if (label * candidate_optimal_dual > 1.0) { |
51 | return label; |
52 | } |
53 | return candidate_optimal_dual; |
54 | } |
55 | |
56 | double ComputeDualLoss(const double current_dual, const double example_label, |
57 | const double example_weight) const final { |
58 | // For binary classification, there are 2 conjugate functions, one per |
59 | // label value (-1 and 1). |
60 | const double y_alpha = current_dual * example_label; // y \alpha |
61 | if (y_alpha < 0 || y_alpha > 1.0) { |
62 | return std::numeric_limits<double>::max(); |
63 | } |
64 | return (-y_alpha + 0.5 * gamma * current_dual * current_dual) * |
65 | example_weight; |
66 | } |
67 | |
68 | double ComputePrimalLoss(const double wx, const double example_label, |
69 | const double example_weight) const final { |
70 | const double y_wx = example_label * wx; |
71 | if (y_wx >= 1) return 0; |
72 | if (y_wx <= 1 - gamma) return (1 - y_wx - gamma / 2) * example_weight; |
73 | return (1 - y_wx) * (1 - y_wx) * example_weight * 0.5 / gamma; |
74 | } |
75 | |
76 | // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively |
77 | // as expected by smooth hinge loss. |
78 | Status ConvertLabel(float* const example_label) const final { |
79 | if (*example_label == 0.0) { |
80 | *example_label = -1; |
81 | return OkStatus(); |
82 | } |
83 | if (*example_label == 1.0) { |
84 | return OkStatus(); |
85 | } |
86 | return errors::InvalidArgument( |
87 | "Only labels of 0.0 or 1.0 are supported right now. " |
88 | "Found example with label: " , |
89 | *example_label); |
90 | } |
91 | |
92 | double PrimalLossDerivative(const double wx, const double label, |
93 | const double example_weight) const final { |
94 | if (label * wx >= 1) { |
95 | return 0; |
96 | } |
97 | if (label * wx <= 1 - gamma) { |
98 | return -label; |
99 | } |
100 | return (wx - label) / gamma; |
101 | } |
102 | |
103 | double SmoothnessConstant() const final { return gamma; } |
104 | |
105 | private: |
106 | // Smoothness constant of smooth hinge loss |
107 | // TODO(sibyl-Aix6ihai): expose this parameter |
108 | const double gamma = 1; |
109 | }; |
110 | |
111 | } // namespace tensorflow |
112 | |
113 | #endif // TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_ |
114 | // TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_ |
115 | |