1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/sdca_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include <stdint.h>
21
22#include <atomic>
23#include <limits>
24#include <memory>
25#include <new>
26#include <string>
27#include <vector>
28
29#include "absl/strings/str_format.h"
30#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
31#include "tensorflow/core/framework/device_base.h"
32#include "tensorflow/core/framework/kernel_def_builder.h"
33#include "tensorflow/core/framework/op.h"
34#include "tensorflow/core/framework/op_def_builder.h"
35#include "tensorflow/core/framework/op_kernel.h"
36#include "tensorflow/core/framework/tensor.h"
37#include "tensorflow/core/framework/tensor_shape.h"
38#include "tensorflow/core/framework/tensor_types.h"
39#include "tensorflow/core/framework/types.h"
40#include "tensorflow/core/kernels/hinge-loss.h"
41#include "tensorflow/core/kernels/logistic-loss.h"
42#include "tensorflow/core/kernels/loss.h"
43#include "tensorflow/core/kernels/poisson-loss.h"
44#include "tensorflow/core/kernels/sdca_internal.h"
45#include "tensorflow/core/kernels/smooth-hinge-loss.h"
46#include "tensorflow/core/kernels/squared-loss.h"
47#include "tensorflow/core/lib/core/coding.h"
48#include "tensorflow/core/lib/core/errors.h"
49#include "tensorflow/core/lib/core/status.h"
50#include "tensorflow/core/lib/core/stringpiece.h"
51#include "tensorflow/core/lib/gtl/inlined_vector.h"
52#include "tensorflow/core/platform/fingerprint.h"
53#include "tensorflow/core/platform/macros.h"
54#include "tensorflow/core/platform/mutex.h"
55#include "tensorflow/core/platform/types.h"
56#include "tensorflow/core/util/work_sharder.h"
57
58namespace tensorflow {
59
60namespace {
61
62using sdca::Example;
63using sdca::Examples;
64using sdca::ExampleStatistics;
65using sdca::ModelWeights;
66using sdca::Regularizations;
67
68struct ComputeOptions {
69 explicit ComputeOptions(OpKernelConstruction* const context) {
70 string loss_type;
71 OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
72 if (loss_type == "logistic_loss") {
73 loss_updater.reset(new LogisticLossUpdater);
74 } else if (loss_type == "squared_loss") {
75 loss_updater.reset(new SquaredLossUpdater);
76 } else if (loss_type == "hinge_loss") {
77 loss_updater.reset(new HingeLossUpdater);
78 } else if (loss_type == "smooth_hinge_loss") {
79 loss_updater.reset(new SmoothHingeLossUpdater);
80 } else if (loss_type == "poisson_loss") {
81 loss_updater.reset(new PoissonLossUpdater);
82 } else {
83 OP_REQUIRES(
84 context, false,
85 errors::InvalidArgument("Unsupported loss type: ", loss_type));
86 }
87 auto s = context->GetAttr("adaptative", &adaptive);
88 if (!s.ok()) {
89 s = context->GetAttr("adaptive", &adaptive);
90 }
91 OP_REQUIRES_OK(context, s);
92 OP_REQUIRES_OK(
93 context, context->GetAttr("num_sparse_features", &num_sparse_features));
94 OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
95 &num_sparse_features_with_values));
96 OP_REQUIRES_OK(context,
97 context->GetAttr("num_dense_features", &num_dense_features));
98 OP_REQUIRES(
99 context, num_sparse_features + num_dense_features > 0,
100 errors::InvalidArgument("Requires at least one feature to train."));
101
102 OP_REQUIRES(context,
103 static_cast<int64_t>(num_sparse_features) +
104 static_cast<int64_t>(num_dense_features) <=
105 std::numeric_limits<int>::max(),
106 errors::InvalidArgument(absl::StrFormat(
107 "Too many feature groups: %d > %d",
108 static_cast<int64_t>(num_sparse_features) +
109 static_cast<int64_t>(num_dense_features),
110 std::numeric_limits<int>::max())));
111 OP_REQUIRES_OK(
112 context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
113 OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
114 &num_inner_iterations));
115 OP_REQUIRES_OK(context, regularizations.Initialize(context));
116 }
117
118 std::unique_ptr<DualLossUpdater> loss_updater;
119 int num_sparse_features = 0;
120 int num_sparse_features_with_values = 0;
121 int num_dense_features = 0;
122 int num_inner_iterations = 0;
123 int num_loss_partitions = 0;
124 bool adaptive = true;
125 Regularizations regularizations;
126};
127
128// TODO(shengx): The helper classes/methods are changed to support multiclass
129// SDCA, which lead to changes within this function. Need to revisit the
130// convergence once the multiclass SDCA is in.
131void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
132 ModelWeights model_weights;
133 OP_REQUIRES_OK(context, model_weights.Initialize(context));
134
135 Examples examples;
136 OP_REQUIRES_OK(
137 context,
138 examples.Initialize(context, model_weights, options.num_sparse_features,
139 options.num_sparse_features_with_values,
140 options.num_dense_features));
141
142 const Tensor* example_state_data_t;
143 OP_REQUIRES_OK(context,
144 context->input("example_state_data", &example_state_data_t));
145 TensorShape expected_example_state_shape({examples.num_examples(), 4});
146 OP_REQUIRES(context,
147 example_state_data_t->shape() == expected_example_state_shape,
148 errors::InvalidArgument(
149 "Expected shape ", expected_example_state_shape.DebugString(),
150 " for example_state_data, got ",
151 example_state_data_t->shape().DebugString()));
152
153 Tensor mutable_example_state_data_t(*example_state_data_t);
154 auto example_state_data = mutable_example_state_data_t.matrix<float>();
155 OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
156 mutable_example_state_data_t));
157
158 if (options.adaptive) {
159 OP_REQUIRES_OK(context,
160 examples.SampleAdaptiveProbabilities(
161 options.num_loss_partitions, options.regularizations,
162 model_weights, example_state_data, options.loss_updater,
163 /*num_weight_vectors =*/1));
164 } else {
165 examples.RandomShuffle();
166 }
167 struct {
168 mutex mu;
169 Status value TF_GUARDED_BY(mu);
170 } train_step_status;
171 std::atomic<std::int64_t> atomic_index(-1);
172 auto train_step = [&](const int64_t begin, const int64_t end) {
173 // The static_cast here is safe since begin and end can be at most
174 // num_examples which is an int.
175 for (int id = static_cast<int>(begin); id < end; ++id) {
176 const int64_t example_index = examples.sampled_index(++atomic_index);
177 const Example& example = examples.example(example_index);
178 const float dual = example_state_data(example_index, 0);
179 const float example_weight = example.example_weight();
180 float example_label = example.example_label();
181 const Status conversion_status =
182 options.loss_updater->ConvertLabel(&example_label);
183 if (!conversion_status.ok()) {
184 mutex_lock l(train_step_status.mu);
185 train_step_status.value = conversion_status;
186 // Return from this worker thread - the calling thread is
187 // responsible for checking context status and returning on error.
188 return;
189 }
190
191 // Compute wx, example norm weighted by regularization, dual loss,
192 // primal loss.
193 // For binary SDCA, num_weight_vectors should be one.
194 const ExampleStatistics example_statistics =
195 example.ComputeWxAndWeightedExampleNorm(
196 options.num_loss_partitions, model_weights,
197 options.regularizations, 1 /* num_weight_vectors */);
198
199 const double new_dual = options.loss_updater->ComputeUpdatedDual(
200 options.num_loss_partitions, example_label, example_weight, dual,
201 example_statistics.wx[0], example_statistics.normalized_squared_norm);
202
203 // Compute new weights.
204 const double normalized_bounded_dual_delta =
205 (new_dual - dual) * example_weight /
206 options.regularizations.symmetric_l2();
207 model_weights.UpdateDeltaWeights(
208 context->eigen_cpu_device(), example,
209 std::vector<double>{normalized_bounded_dual_delta});
210
211 // Update example data.
212 example_state_data(example_index, 0) = new_dual;
213 example_state_data(example_index, 1) =
214 options.loss_updater->ComputePrimalLoss(
215 example_statistics.prev_wx[0], example_label, example_weight);
216 example_state_data(example_index, 2) =
217 options.loss_updater->ComputeDualLoss(dual, example_label,
218 example_weight);
219 example_state_data(example_index, 3) = example_weight;
220 }
221 };
222 // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
223 // number of cpus, and cost per example.
224 const int64_t kCostPerUnit = examples.num_features();
225 const DeviceBase::CpuWorkerThreads& worker_threads =
226 *context->device()->tensorflow_cpu_worker_threads();
227
228 Shard(worker_threads.num_threads, worker_threads.workers,
229 examples.num_examples(), kCostPerUnit, train_step);
230 mutex_lock l(train_step_status.mu);
231 OP_REQUIRES_OK(context, train_step_status.value);
232}
233
234} // namespace
235
236class SdcaOptimizer : public OpKernel {
237 public:
238 explicit SdcaOptimizer(OpKernelConstruction* const context)
239 : OpKernel(context), options_(context) {}
240
241 void Compute(OpKernelContext* context) override {
242 DoCompute(options_, context);
243 }
244
245 private:
246 // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
247 // template the entire class to avoid the virtual table lookup penalty in
248 // the inner loop.
249 ComputeOptions options_;
250};
251REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
252 SdcaOptimizer);
253REGISTER_KERNEL_BUILDER(Name("SdcaOptimizerV2").Device(DEVICE_CPU),
254 SdcaOptimizer);
255
256class SdcaShrinkL1 : public OpKernel {
257 public:
258 explicit SdcaShrinkL1(OpKernelConstruction* const context)
259 : OpKernel(context) {
260 OP_REQUIRES_OK(context, regularizations_.Initialize(context));
261 }
262
263 void Compute(OpKernelContext* context) override {
264 OpMutableInputList weights_inputs;
265 OP_REQUIRES_OK(context,
266 context->mutable_input_list("weights", &weights_inputs));
267
268 auto do_work = [&](const int64_t begin, const int64_t end) {
269 for (int i = begin; i < end; ++i) {
270 auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
271 prox_w.device(context->eigen_cpu_device()) =
272 regularizations_.EigenShrinkVector(prox_w);
273 }
274 };
275
276 if (weights_inputs.size() > 0) {
277 int64_t num_weights = 0;
278 for (int i = 0; i < weights_inputs.size(); ++i) {
279 num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
280 }
281 // TODO(sibyl-Aix6ihai): Tune this value.
282 const int64_t kCostPerUnit = (num_weights * 50) / weights_inputs.size();
283 const DeviceBase::CpuWorkerThreads& worker_threads =
284 *context->device()->tensorflow_cpu_worker_threads();
285 Shard(worker_threads.num_threads, worker_threads.workers,
286 weights_inputs.size(), kCostPerUnit, do_work);
287 }
288 }
289
290 private:
291 Regularizations regularizations_;
292};
293REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
294
295// Computes platform independent, compact and unique (with very high
296// probability) representation of an example id. It shouldn't be put in
297// persistent storage, as its implementation may change in the future.
298//
299// The current probability of at least one collision for 1B example_ids is
300// approximately 10^-21 (ie 2^60 / 2^129).
301class SdcaFprint : public OpKernel {
302 public:
303 explicit SdcaFprint(OpKernelConstruction* const context)
304 : OpKernel(context) {}
305
306 void Compute(OpKernelContext* context) override {
307 const Tensor& input = context->input(0);
308 OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
309 errors::InvalidArgument("Input must be a vector, got shape ",
310 input.shape().DebugString()));
311 Tensor* out;
312 const int64_t num_elements = input.NumElements();
313 OP_REQUIRES_OK(context, context->allocate_output(
314 0, TensorShape({num_elements, 2}), &out));
315
316 const auto in_values = input.flat<tstring>();
317 auto out_values = out->matrix<int64_t>();
318
319 for (int64_t i = 0; i < num_elements; ++i) {
320 const Fprint128 fprint = Fingerprint128(in_values(i));
321 // Never return 0 or 1 as the first value of the hash to allow these to
322 // safely be used as sentinel values (e.g. dense hash table empty key).
323 out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
324 ? fprint.low64
325 : fprint.low64 + ~static_cast<uint64>(1);
326 out_values(i, 1) = fprint.high64;
327 }
328 }
329};
330REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
331
332} // namespace tensorflow
333