1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_
17#define TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_
18
19#define EIGEN_USE_THREADS
20
21#include <stddef.h>
22#include <algorithm>
23#include <cmath>
24#include <memory>
25#include <new>
26#include <unordered_map>
27#include <utility>
28#include <vector>
29
30#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
31#include "tensorflow/core/framework/device_base.h"
32#include "tensorflow/core/framework/op_kernel.h"
33#include "tensorflow/core/framework/tensor.h"
34#include "tensorflow/core/framework/tensor_shape.h"
35#include "tensorflow/core/framework/tensor_types.h"
36#include "tensorflow/core/framework/types.h"
37#include "tensorflow/core/kernels/loss.h"
38#include "tensorflow/core/lib/core/coding.h"
39#include "tensorflow/core/lib/core/errors.h"
40#include "tensorflow/core/lib/core/status.h"
41#include "tensorflow/core/lib/core/stringpiece.h"
42#include "tensorflow/core/lib/gtl/inlined_vector.h"
43#include "tensorflow/core/lib/random/distribution_sampler.h"
44#include "tensorflow/core/lib/strings/stringprintf.h"
45#include "tensorflow/core/util/guarded_philox_random.h"
46#include "tensorflow/core/util/work_sharder.h"
47
48namespace tensorflow {
49
50namespace sdca {
51
52// Statistics computed with input (ModelWeights, Example).
53struct ExampleStatistics {
54 // Logits for each class.
55 // For binary case, this should be a vector of length 1; while for multiclass
56 // case, this vector has the same length as the number of classes, where each
57 // value corresponds to one class.
58 // Use InlinedVector to avoid heap allocation for small number of classes.
59 gtl::InlinedVector<double, 1> wx;
60
61 // Logits for each class, using the previous weights.
62 gtl::InlinedVector<double, 1> prev_wx;
63
64 // Sum of squared feature values occurring in the example divided by
65 // L2 * sum(example_weights).
66 double normalized_squared_norm = 0;
67
68 // Num_weight_vectors equals to the number of classification classes in the
69 // multiclass case; while for binary case, it is 1.
70 ExampleStatistics(const int num_weight_vectors)
71 : wx(num_weight_vectors, 0.0), prev_wx(num_weight_vectors, 0.0) {}
72};
73
74class Regularizations {
75 public:
76 Regularizations() {}
77
78 // Initialize() must be called immediately after construction.
79 Status Initialize(OpKernelConstruction* const context) {
80 TF_RETURN_IF_ERROR(context->GetAttr("l1", &symmetric_l1_));
81 TF_RETURN_IF_ERROR(context->GetAttr("l2", &symmetric_l2_));
82 shrinkage_ = symmetric_l1_ / symmetric_l2_;
83 return OkStatus();
84 }
85
86 // Proximal SDCA shrinking for L1 regularization.
87 double Shrink(const double weight) const {
88 const double shrinked = std::max(std::abs(weight) - shrinkage_, 0.0);
89 if (shrinked > 0.0) {
90 return std::copysign(shrinked, weight);
91 }
92 return 0.0;
93 }
94
95 // Vectorized float variant of the above.
96 Eigen::Tensor<float, 1, Eigen::RowMajor> EigenShrinkVector(
97 const Eigen::Tensor<float, 1, Eigen::RowMajor> weights) const {
98 // Proximal step on the weights which is sign(w)*|w - shrinkage|+.
99 return weights.sign() * ((weights.abs() - weights.constant(shrinkage_))
100 .cwiseMax(weights.constant(0.0)));
101 }
102
103 // Matrix float variant of the above.
104 Eigen::Tensor<float, 2, Eigen::RowMajor> EigenShrinkMatrix(
105 const Eigen::Tensor<float, 2, Eigen::RowMajor> weights) const {
106 // Proximal step on the weights which is sign(w)*|w - shrinkage|+.
107 return weights.sign() * ((weights.abs() - weights.constant(shrinkage_))
108 .cwiseMax(weights.constant(0.0)));
109 }
110
111 float symmetric_l2() const { return symmetric_l2_; }
112
113 private:
114 float symmetric_l1_ = 0;
115 float symmetric_l2_ = 0;
116
117 // L1 divided by L2, pre-computed for use during weight shrinking.
118 double shrinkage_ = 0;
119
120 TF_DISALLOW_COPY_AND_ASSIGN(Regularizations);
121};
122
123class ModelWeights;
124
125// Struct describing a single example.
126class Example {
127 public:
128 // Compute matrix vector product between weights (a matrix) and features
129 // (a vector). This method also computes the normalized example norm used
130 // in SDCA update.
131 // For multiclass case, num_weight_vectors equals to the number of classes;
132 // while for binary case, it is 1.
133 const ExampleStatistics ComputeWxAndWeightedExampleNorm(
134 const int num_loss_partitions, const ModelWeights& model_weights,
135 const Regularizations& regularization,
136 const int num_weight_vectors) const;
137
138 float example_label() const { return example_label_; }
139
140 float example_weight() const { return example_weight_; }
141
142 double squared_norm() const { return squared_norm_; }
143
144 // Sparse features associated with the example.
145 // Indices and Values are the associated feature index, and values. Values
146 // can be optionally absent, in which we case we implicitly assume a value of
147 // 1.0f.
148 struct SparseFeatures {
149 std::unique_ptr<TTypes<const int64_t>::UnalignedConstVec> indices;
150 std::unique_ptr<TTypes<const float>::UnalignedConstVec>
151 values; // nullptr encodes optional.
152 };
153
154 // A dense vector which is a row-slice of the underlying matrix.
155 struct DenseVector {
156 // Returns a row slice from the matrix.
157 Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>> Row()
158 const {
159 return Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>>(
160 data_matrix.data() + row_index * data_matrix.dimension(1),
161 data_matrix.dimension(1));
162 }
163
164 // Returns a row slice as a 1 * F matrix, where F is the number of features.
165 Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>>
166 RowAsMatrix() const {
167 return Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>>(
168 data_matrix.data() + row_index * data_matrix.dimension(1), 1,
169 data_matrix.dimension(1));
170 }
171
172 const TTypes<float>::ConstMatrix data_matrix;
173 const int64_t row_index;
174 };
175
176 private:
177 std::vector<SparseFeatures> sparse_features_;
178 std::vector<std::unique_ptr<DenseVector>> dense_vectors_;
179
180 float example_label_ = 0;
181 float example_weight_ = 0;
182 double squared_norm_ = 0; // sum squared norm of the features.
183
184 // Examples fills Example in a multi-threaded way.
185 friend class Examples;
186
187 // ModelWeights use each example for model update w += \alpha * x_{i};
188 friend class ModelWeights;
189};
190
191// Weights related to features. For example, say you have two sets of sparse
192// features i.e. age bracket and country, then FeatureWeightsDenseStorage hold
193// the parameters for it. We keep track of the original weight passed in and the
194// delta weight which the optimizer learns in each call to the optimizer.
195class FeatureWeightsDenseStorage {
196 public:
197 FeatureWeightsDenseStorage(const TTypes<const float>::Matrix nominals,
198 TTypes<float>::Matrix deltas)
199 : nominals_(nominals), deltas_(deltas) {
200 CHECK_GT(deltas.rank(), 1);
201 }
202
203 // Check if a feature index is with-in the bounds.
204 bool IndexValid(const int64_t index) const {
205 return index >= 0 && index < deltas_.dimension(1);
206 }
207
208 // Nominals here are the original weight matrix.
209 TTypes<const float>::Matrix nominals() const { return nominals_; }
210
211 // Delta weights during mini-batch updates.
212 TTypes<float>::Matrix deltas() const { return deltas_; }
213
214 // Updates delta weights based on active dense features in the example and
215 // the corresponding dual residual.
216 void UpdateDenseDeltaWeights(
217 const Eigen::ThreadPoolDevice& device,
218 const Example::DenseVector& dense_vector,
219 const std::vector<double>& normalized_bounded_dual_delta);
220
221 private:
222 // The nominal value of the weight for a feature (indexed by its id).
223 const TTypes<const float>::Matrix nominals_;
224 // The accumulated delta weight for a feature (indexed by its id).
225 TTypes<float>::Matrix deltas_;
226};
227
228// Similar to FeatureWeightsDenseStorage, but the underlying weights are stored
229// in an unordered map.
230class FeatureWeightsSparseStorage {
231 public:
232 FeatureWeightsSparseStorage(const TTypes<const int64_t>::Vec indices,
233 const TTypes<const float>::Matrix nominals,
234 TTypes<float>::Matrix deltas)
235 : nominals_(nominals), deltas_(deltas) {
236 // Create a map from sparse index to the dense index of the underlying
237 // storage.
238 for (int64_t j = 0; j < indices.size(); ++j) {
239 indices_to_id_[indices(j)] = j;
240 }
241 }
242
243 // Check if a feature index exists.
244 bool IndexValid(const int64_t index) const {
245 return indices_to_id_.find(index) != indices_to_id_.end();
246 }
247
248 // Nominal value at a particular feature index and class label.
249 float nominals(const int class_id, const int64_t index) const {
250 auto it = indices_to_id_.find(index);
251 return nominals_(class_id, it->second);
252 }
253
254 // Delta weights during mini-batch updates.
255 float deltas(const int class_id, const int64_t index) const {
256 auto it = indices_to_id_.find(index);
257 return deltas_(class_id, it->second);
258 }
259
260 // Updates delta weights based on active sparse features in the example and
261 // the corresponding dual residual.
262 void UpdateSparseDeltaWeights(
263 const Eigen::ThreadPoolDevice& device,
264 const Example::SparseFeatures& sparse_features,
265 const std::vector<double>& normalized_bounded_dual_delta);
266
267 private:
268 // The nominal value of the weight for a feature (indexed by its id).
269 const TTypes<const float>::Matrix nominals_;
270 // The accumulated delta weight for a feature (indexed by its id).
271 TTypes<float>::Matrix deltas_;
272 // Map from feature index to an index to the dense vector.
273 std::unordered_map<int64_t, int64_t> indices_to_id_;
274};
275
276// Weights in the model, wraps both current weights, and the delta weights
277// for both sparse and dense features.
278class ModelWeights {
279 public:
280 ModelWeights() {}
281
282 bool SparseIndexValid(const int col, const int64_t index) const {
283 return sparse_weights_[col].IndexValid(index);
284 }
285
286 bool DenseIndexValid(const int col, const int64_t index) const {
287 return dense_weights_[col].IndexValid(index);
288 }
289
290 // Go through all the features present in the example, and update the
291 // weights based on the dual delta.
292 void UpdateDeltaWeights(
293 const Eigen::ThreadPoolDevice& device, const Example& example,
294 const std::vector<double>& normalized_bounded_dual_delta);
295
296 Status Initialize(OpKernelContext* const context);
297
298 const std::vector<FeatureWeightsSparseStorage>& sparse_weights() const {
299 return sparse_weights_;
300 }
301
302 const std::vector<FeatureWeightsDenseStorage>& dense_weights() const {
303 return dense_weights_;
304 }
305
306 private:
307 std::vector<FeatureWeightsSparseStorage> sparse_weights_;
308 std::vector<FeatureWeightsDenseStorage> dense_weights_;
309
310 TF_DISALLOW_COPY_AND_ASSIGN(ModelWeights);
311};
312
313// Examples contains all the training examples that SDCA uses for a mini-batch.
314class Examples {
315 public:
316 Examples() {}
317
318 // Returns the Example at |example_index|.
319 const Example& example(const int example_index) const {
320 return examples_.at(example_index);
321 }
322
323 int sampled_index(const int id) const { return sampled_index_[id]; }
324
325 // Adaptive SDCA in the current implementation only works for
326 // binary classification, where the input argument for num_weight_vectors
327 // is 1.
328 Status SampleAdaptiveProbabilities(
329 const int num_loss_partitions, const Regularizations& regularization,
330 const ModelWeights& model_weights,
331 const TTypes<float>::Matrix example_state_data,
332 const std::unique_ptr<DualLossUpdater>& loss_updater,
333 const int num_weight_vectors);
334
335 void RandomShuffle();
336
337 int num_examples() const { return examples_.size(); }
338
339 int num_features() const { return num_features_; }
340
341 // Initialize() must be called immediately after construction.
342 Status Initialize(OpKernelContext* const context, const ModelWeights& weights,
343 int num_sparse_features,
344 int num_sparse_features_with_values,
345 int num_dense_features);
346
347 private:
348 // Reads the input tensors, and builds the internal representation for sparse
349 // features per example. This function modifies the |examples| passed in
350 // to build the sparse representations.
351 static Status CreateSparseFeatureRepresentation(
352 const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples,
353 int num_sparse_features, const ModelWeights& weights,
354 const OpInputList& sparse_example_indices_inputs,
355 const OpInputList& sparse_feature_indices_inputs,
356 const OpInputList& sparse_feature_values_inputs,
357 std::vector<Example>* const examples);
358
359 // Reads the input tensors, and builds the internal representation for dense
360 // features per example. This function modifies the |examples| passed in
361 // to build the sparse representations.
362 static Status CreateDenseFeatureRepresentation(
363 const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples,
364 int num_dense_features, const ModelWeights& weights,
365 const OpInputList& dense_features_inputs,
366 std::vector<Example>* const examples);
367
368 // Computes squared example norm per example i.e |x|^2. This function modifies
369 // the |examples| passed in and adds the squared norm per example.
370 static Status ComputeSquaredNormPerExample(
371 const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples,
372 int num_sparse_features, int num_dense_features,
373 std::vector<Example>* const examples);
374
375 // All examples in the batch.
376 std::vector<Example> examples_;
377
378 // Adaptive sampling variables.
379 std::vector<float> probabilities_;
380 std::vector<int> sampled_index_;
381 std::vector<int> sampled_count_;
382
383 int num_features_ = 0;
384
385 TF_DISALLOW_COPY_AND_ASSIGN(Examples);
386};
387
388} // namespace sdca
389} // namespace tensorflow
390
391#endif // TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_
392