1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/sdca_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include <stdint.h> |
21 | |
22 | #include <atomic> |
23 | #include <limits> |
24 | #include <memory> |
25 | #include <new> |
26 | #include <string> |
27 | #include <vector> |
28 | |
29 | #include "absl/strings/str_format.h" |
30 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
31 | #include "tensorflow/core/framework/device_base.h" |
32 | #include "tensorflow/core/framework/kernel_def_builder.h" |
33 | #include "tensorflow/core/framework/op.h" |
34 | #include "tensorflow/core/framework/op_def_builder.h" |
35 | #include "tensorflow/core/framework/op_kernel.h" |
36 | #include "tensorflow/core/framework/tensor.h" |
37 | #include "tensorflow/core/framework/tensor_shape.h" |
38 | #include "tensorflow/core/framework/tensor_types.h" |
39 | #include "tensorflow/core/framework/types.h" |
40 | #include "tensorflow/core/kernels/hinge-loss.h" |
41 | #include "tensorflow/core/kernels/logistic-loss.h" |
42 | #include "tensorflow/core/kernels/loss.h" |
43 | #include "tensorflow/core/kernels/poisson-loss.h" |
44 | #include "tensorflow/core/kernels/sdca_internal.h" |
45 | #include "tensorflow/core/kernels/smooth-hinge-loss.h" |
46 | #include "tensorflow/core/kernels/squared-loss.h" |
47 | #include "tensorflow/core/lib/core/coding.h" |
48 | #include "tensorflow/core/lib/core/errors.h" |
49 | #include "tensorflow/core/lib/core/status.h" |
50 | #include "tensorflow/core/lib/core/stringpiece.h" |
51 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
52 | #include "tensorflow/core/platform/fingerprint.h" |
53 | #include "tensorflow/core/platform/macros.h" |
54 | #include "tensorflow/core/platform/mutex.h" |
55 | #include "tensorflow/core/platform/types.h" |
56 | #include "tensorflow/core/util/work_sharder.h" |
57 | |
58 | namespace tensorflow { |
59 | |
60 | namespace { |
61 | |
62 | using sdca::Example; |
63 | using sdca::Examples; |
64 | using sdca::ExampleStatistics; |
65 | using sdca::ModelWeights; |
66 | using sdca::Regularizations; |
67 | |
68 | struct ComputeOptions { |
69 | explicit ComputeOptions(OpKernelConstruction* const context) { |
70 | string loss_type; |
71 | OP_REQUIRES_OK(context, context->GetAttr("loss_type" , &loss_type)); |
72 | if (loss_type == "logistic_loss" ) { |
73 | loss_updater.reset(new LogisticLossUpdater); |
74 | } else if (loss_type == "squared_loss" ) { |
75 | loss_updater.reset(new SquaredLossUpdater); |
76 | } else if (loss_type == "hinge_loss" ) { |
77 | loss_updater.reset(new HingeLossUpdater); |
78 | } else if (loss_type == "smooth_hinge_loss" ) { |
79 | loss_updater.reset(new SmoothHingeLossUpdater); |
80 | } else if (loss_type == "poisson_loss" ) { |
81 | loss_updater.reset(new PoissonLossUpdater); |
82 | } else { |
83 | OP_REQUIRES( |
84 | context, false, |
85 | errors::InvalidArgument("Unsupported loss type: " , loss_type)); |
86 | } |
87 | auto s = context->GetAttr("adaptative" , &adaptive); |
88 | if (!s.ok()) { |
89 | s = context->GetAttr("adaptive" , &adaptive); |
90 | } |
91 | OP_REQUIRES_OK(context, s); |
92 | OP_REQUIRES_OK( |
93 | context, context->GetAttr("num_sparse_features" , &num_sparse_features)); |
94 | OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values" , |
95 | &num_sparse_features_with_values)); |
96 | OP_REQUIRES_OK(context, |
97 | context->GetAttr("num_dense_features" , &num_dense_features)); |
98 | OP_REQUIRES( |
99 | context, num_sparse_features + num_dense_features > 0, |
100 | errors::InvalidArgument("Requires at least one feature to train." )); |
101 | |
102 | OP_REQUIRES(context, |
103 | static_cast<int64_t>(num_sparse_features) + |
104 | static_cast<int64_t>(num_dense_features) <= |
105 | std::numeric_limits<int>::max(), |
106 | errors::InvalidArgument(absl::StrFormat( |
107 | "Too many feature groups: %d > %d" , |
108 | static_cast<int64_t>(num_sparse_features) + |
109 | static_cast<int64_t>(num_dense_features), |
110 | std::numeric_limits<int>::max()))); |
111 | OP_REQUIRES_OK( |
112 | context, context->GetAttr("num_loss_partitions" , &num_loss_partitions)); |
113 | OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations" , |
114 | &num_inner_iterations)); |
115 | OP_REQUIRES_OK(context, regularizations.Initialize(context)); |
116 | } |
117 | |
118 | std::unique_ptr<DualLossUpdater> loss_updater; |
119 | int num_sparse_features = 0; |
120 | int num_sparse_features_with_values = 0; |
121 | int num_dense_features = 0; |
122 | int num_inner_iterations = 0; |
123 | int num_loss_partitions = 0; |
124 | bool adaptive = true; |
125 | Regularizations regularizations; |
126 | }; |
127 | |
128 | // TODO(shengx): The helper classes/methods are changed to support multiclass |
129 | // SDCA, which lead to changes within this function. Need to revisit the |
130 | // convergence once the multiclass SDCA is in. |
131 | void DoCompute(const ComputeOptions& options, OpKernelContext* const context) { |
132 | ModelWeights model_weights; |
133 | OP_REQUIRES_OK(context, model_weights.Initialize(context)); |
134 | |
135 | Examples examples; |
136 | OP_REQUIRES_OK( |
137 | context, |
138 | examples.Initialize(context, model_weights, options.num_sparse_features, |
139 | options.num_sparse_features_with_values, |
140 | options.num_dense_features)); |
141 | |
142 | const Tensor* example_state_data_t; |
143 | OP_REQUIRES_OK(context, |
144 | context->input("example_state_data" , &example_state_data_t)); |
145 | TensorShape expected_example_state_shape({examples.num_examples(), 4}); |
146 | OP_REQUIRES(context, |
147 | example_state_data_t->shape() == expected_example_state_shape, |
148 | errors::InvalidArgument( |
149 | "Expected shape " , expected_example_state_shape.DebugString(), |
150 | " for example_state_data, got " , |
151 | example_state_data_t->shape().DebugString())); |
152 | |
153 | Tensor mutable_example_state_data_t(*example_state_data_t); |
154 | auto example_state_data = mutable_example_state_data_t.matrix<float>(); |
155 | OP_REQUIRES_OK(context, context->set_output("out_example_state_data" , |
156 | mutable_example_state_data_t)); |
157 | |
158 | if (options.adaptive) { |
159 | OP_REQUIRES_OK(context, |
160 | examples.SampleAdaptiveProbabilities( |
161 | options.num_loss_partitions, options.regularizations, |
162 | model_weights, example_state_data, options.loss_updater, |
163 | /*num_weight_vectors =*/1)); |
164 | } else { |
165 | examples.RandomShuffle(); |
166 | } |
167 | struct { |
168 | mutex mu; |
169 | Status value TF_GUARDED_BY(mu); |
170 | } train_step_status; |
171 | std::atomic<std::int64_t> atomic_index(-1); |
172 | auto train_step = [&](const int64_t begin, const int64_t end) { |
173 | // The static_cast here is safe since begin and end can be at most |
174 | // num_examples which is an int. |
175 | for (int id = static_cast<int>(begin); id < end; ++id) { |
176 | const int64_t example_index = examples.sampled_index(++atomic_index); |
177 | const Example& example = examples.example(example_index); |
178 | const float dual = example_state_data(example_index, 0); |
179 | const float example_weight = example.example_weight(); |
180 | float example_label = example.example_label(); |
181 | const Status conversion_status = |
182 | options.loss_updater->ConvertLabel(&example_label); |
183 | if (!conversion_status.ok()) { |
184 | mutex_lock l(train_step_status.mu); |
185 | train_step_status.value = conversion_status; |
186 | // Return from this worker thread - the calling thread is |
187 | // responsible for checking context status and returning on error. |
188 | return; |
189 | } |
190 | |
191 | // Compute wx, example norm weighted by regularization, dual loss, |
192 | // primal loss. |
193 | // For binary SDCA, num_weight_vectors should be one. |
194 | const ExampleStatistics example_statistics = |
195 | example.ComputeWxAndWeightedExampleNorm( |
196 | options.num_loss_partitions, model_weights, |
197 | options.regularizations, 1 /* num_weight_vectors */); |
198 | |
199 | const double new_dual = options.loss_updater->ComputeUpdatedDual( |
200 | options.num_loss_partitions, example_label, example_weight, dual, |
201 | example_statistics.wx[0], example_statistics.normalized_squared_norm); |
202 | |
203 | // Compute new weights. |
204 | const double normalized_bounded_dual_delta = |
205 | (new_dual - dual) * example_weight / |
206 | options.regularizations.symmetric_l2(); |
207 | model_weights.UpdateDeltaWeights( |
208 | context->eigen_cpu_device(), example, |
209 | std::vector<double>{normalized_bounded_dual_delta}); |
210 | |
211 | // Update example data. |
212 | example_state_data(example_index, 0) = new_dual; |
213 | example_state_data(example_index, 1) = |
214 | options.loss_updater->ComputePrimalLoss( |
215 | example_statistics.prev_wx[0], example_label, example_weight); |
216 | example_state_data(example_index, 2) = |
217 | options.loss_updater->ComputeDualLoss(dual, example_label, |
218 | example_weight); |
219 | example_state_data(example_index, 3) = example_weight; |
220 | } |
221 | }; |
222 | // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data, |
223 | // number of cpus, and cost per example. |
224 | const int64_t kCostPerUnit = examples.num_features(); |
225 | const DeviceBase::CpuWorkerThreads& worker_threads = |
226 | *context->device()->tensorflow_cpu_worker_threads(); |
227 | |
228 | Shard(worker_threads.num_threads, worker_threads.workers, |
229 | examples.num_examples(), kCostPerUnit, train_step); |
230 | mutex_lock l(train_step_status.mu); |
231 | OP_REQUIRES_OK(context, train_step_status.value); |
232 | } |
233 | |
234 | } // namespace |
235 | |
236 | class SdcaOptimizer : public OpKernel { |
237 | public: |
238 | explicit SdcaOptimizer(OpKernelConstruction* const context) |
239 | : OpKernel(context), options_(context) {} |
240 | |
241 | void Compute(OpKernelContext* context) override { |
242 | DoCompute(options_, context); |
243 | } |
244 | |
245 | private: |
246 | // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and |
247 | // template the entire class to avoid the virtual table lookup penalty in |
248 | // the inner loop. |
249 | ComputeOptions options_; |
250 | }; |
251 | REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer" ).Device(DEVICE_CPU), |
252 | SdcaOptimizer); |
253 | REGISTER_KERNEL_BUILDER(Name("SdcaOptimizerV2" ).Device(DEVICE_CPU), |
254 | SdcaOptimizer); |
255 | |
256 | class SdcaShrinkL1 : public OpKernel { |
257 | public: |
258 | explicit SdcaShrinkL1(OpKernelConstruction* const context) |
259 | : OpKernel(context) { |
260 | OP_REQUIRES_OK(context, regularizations_.Initialize(context)); |
261 | } |
262 | |
263 | void Compute(OpKernelContext* context) override { |
264 | OpMutableInputList weights_inputs; |
265 | OP_REQUIRES_OK(context, |
266 | context->mutable_input_list("weights" , &weights_inputs)); |
267 | |
268 | auto do_work = [&](const int64_t begin, const int64_t end) { |
269 | for (int i = begin; i < end; ++i) { |
270 | auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>(); |
271 | prox_w.device(context->eigen_cpu_device()) = |
272 | regularizations_.EigenShrinkVector(prox_w); |
273 | } |
274 | }; |
275 | |
276 | if (weights_inputs.size() > 0) { |
277 | int64_t num_weights = 0; |
278 | for (int i = 0; i < weights_inputs.size(); ++i) { |
279 | num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements(); |
280 | } |
281 | // TODO(sibyl-Aix6ihai): Tune this value. |
282 | const int64_t kCostPerUnit = (num_weights * 50) / weights_inputs.size(); |
283 | const DeviceBase::CpuWorkerThreads& worker_threads = |
284 | *context->device()->tensorflow_cpu_worker_threads(); |
285 | Shard(worker_threads.num_threads, worker_threads.workers, |
286 | weights_inputs.size(), kCostPerUnit, do_work); |
287 | } |
288 | } |
289 | |
290 | private: |
291 | Regularizations regularizations_; |
292 | }; |
293 | REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1" ).Device(DEVICE_CPU), SdcaShrinkL1); |
294 | |
295 | // Computes platform independent, compact and unique (with very high |
296 | // probability) representation of an example id. It shouldn't be put in |
297 | // persistent storage, as its implementation may change in the future. |
298 | // |
299 | // The current probability of at least one collision for 1B example_ids is |
300 | // approximately 10^-21 (ie 2^60 / 2^129). |
301 | class SdcaFprint : public OpKernel { |
302 | public: |
303 | explicit SdcaFprint(OpKernelConstruction* const context) |
304 | : OpKernel(context) {} |
305 | |
306 | void Compute(OpKernelContext* context) override { |
307 | const Tensor& input = context->input(0); |
308 | OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), |
309 | errors::InvalidArgument("Input must be a vector, got shape " , |
310 | input.shape().DebugString())); |
311 | Tensor* out; |
312 | const int64_t num_elements = input.NumElements(); |
313 | OP_REQUIRES_OK(context, context->allocate_output( |
314 | 0, TensorShape({num_elements, 2}), &out)); |
315 | |
316 | const auto in_values = input.flat<tstring>(); |
317 | auto out_values = out->matrix<int64_t>(); |
318 | |
319 | for (int64_t i = 0; i < num_elements; ++i) { |
320 | const Fprint128 fprint = Fingerprint128(in_values(i)); |
321 | // Never return 0 or 1 as the first value of the hash to allow these to |
322 | // safely be used as sentinel values (e.g. dense hash table empty key). |
323 | out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2) |
324 | ? fprint.low64 |
325 | : fprint.low64 + ~static_cast<uint64>(1); |
326 | out_values(i, 1) = fprint.high64; |
327 | } |
328 | } |
329 | }; |
330 | REGISTER_KERNEL_BUILDER(Name("SdcaFprint" ).Device(DEVICE_CPU), SdcaFprint); |
331 | |
332 | } // namespace tensorflow |
333 | |