1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/platform/threadpool.h"
22
23namespace {
24
25using ::int64_t;
26using tensorflow::int32;
27
// The number of operations the isotonic regression solver is estimated to
// perform on one problem is the problem size (row length) multiplied by this
// constant. The thread pool executor uses the estimate when deciding how many
// threads to use.
//
// Declared as int64_t so that multiplying it by an `int` row length is done in
// 64-bit arithmetic: a row length close to INT32_MAX times 100 does not fit in
// an `int`, and signed overflow is undefined behavior.
constexpr int64_t kCostMultiplier = 100;
32
// A contiguous run of coordinates, stored as the half-open interval
// [col_start(), col_limit()).
//
// Such runs are the natural building block of separable chain-constrained
// problems, i.e. problems of the form
//
//   min_{y_1 >= y_2 >= ... >= y_n} \sum_{i=1}^n h_i(y_i)
//
// where every h_i is a convex function.
class Segment {
 public:
  // Constructs the single-point segment [col_index, col_index + 1).
  explicit Segment(int col_index)
      : col_start_(col_index), col_limit_(col_index + 1) {}

  // First coordinate covered by the segment (inclusive).
  int col_start() const { return col_start_; }

  // One past the last coordinate covered by the segment (exclusive).
  int col_limit() const { return col_limit_; }

  // Number of coordinates covered by the segment.
  int num_points() const { return col_limit() - col_start(); }

  // Grows this segment to the smallest interval containing both itself and
  // `other`.
  void merge_with(const Segment& other) {
    if (other.col_start() < col_start_) col_start_ = other.col_start();
    if (col_limit_ < other.col_limit()) col_limit_ = other.col_limit();
  }

 private:
  int col_start_;  // Inclusive lower bound.
  int col_limit_;  // Exclusive upper bound.
};
63
64// If we can solve for each segment {j, j+1, ..., j+m} the interval problem
65//
66// argmin_y \sum_{i=j}^{j+m} h_i(y),
67//
68// we can use such an oracle to solve the general problem. The following class
69// implements such an oracle for the case when h_i is the squared (l2) loss,
70// or formally h_i(y) = (y - x_i)^2, where x_i is the i-th input.
71//
72// TODO(josipd): We know how and can extend this to other functions if needed.
73template <typename T>
74class L2PavaSegment : public Segment {
75 public:
76 L2PavaSegment(T y, int col_index)
77 : Segment(col_index), y_sum_(y), minimum_(y) {}
78
79 void merge_with(const L2PavaSegment& other) {
80 Segment::merge_with(other);
81 y_sum_ += other.y_sum_;
82 minimum_ = y_sum_ / static_cast<T>(num_points());
83 }
84
85 T minimum() const { return minimum_; }
86
87 private:
88 T y_sum_; // The sum of the inputs within the segment.
89 T minimum_; // The minimum, cached to avoid expensive divisions.
90};
91
// Solves one problem of the batch (row `row_index`) with the pool-adjacent
// violators algorithm (PAVA).
//
// The PAVA algorithm goes back to
//
//   Nonmetric Multidimensional Scaling: A numerical method
//   Kruskal, J. B. (1964), Psychometrika (1964)
//
// For a more recent analysis, please refer to
//
//   Active set algorithms for isotonic regression; a unifying framework
//   Best, Michael J., and Nilotpal Chakravarti
//   Mathematical Programming 47.1-3 (1990)
//
// Intuitively, the algorithm splits the inputs into blocks (starting from
// singleton ones), and then whenever there are two consecutive blocks whose
// minima violate the inequality constraint, they are merged. The solution is
// then block-wise constant, each block equal to the corresponding minimum.
//
// Parameters:
//   make_segment: builds the singleton segment for (row, column).
//   solution, segments: two-dimensional, row-major tensors indexable as
//     (*tensor)(row, col). On return, row `row_index` of `solution` holds each
//     coordinate's block minimum, and the same row of `segments` holds the
//     0-based id of the block the coordinate belongs to (numbered left to
//     right).
//   row_index: the row (problem) to solve.
//
// SegmentType must provide minimum(), merge_with(), col_start() and
// col_limit().
template <typename SegmentType, typename FloatTensor, typename IntTensor>
void solve_pava(const std::function<SegmentType(int, int)>& make_segment,
                FloatTensor* solution, IntTensor* segments, int row_index) {
  // Columns are indexed with `int`, matching the (int, int) signature of
  // `make_segment` and the int-based column bounds of the segments.
  const int n = static_cast<int>(solution->dimensions()[1]);
  std::vector<SegmentType> pools;
  pools.reserve(static_cast<size_t>(n));

  for (int col_index = 0; col_index < n; ++col_index) {
    pools.push_back(make_segment(row_index, col_index));

    // The solution must be non-increasing. While the newest pool's minimum is
    // larger than the one of the pool before it (a violation), merge the
    // newest pool into its predecessor. A merge can expose a new violation
    // one step further back, hence the loop.
    while (pools.size() > 1 &&
           pools.back().minimum() > pools[pools.size() - 2].minimum()) {
      pools[pools.size() - 2].merge_with(pools.back());
      pools.pop_back();
    }
  }

  // Write each pool's minimum to all of its coordinates, and label those
  // coordinates with the pool's index.
  int segment_id = 0;
  for (const auto& pool : pools) {
    const auto pool_minimum = pool.minimum();
    // The matrices are row major, so we can scan the memory linearly.
    auto* solution_ptr = &(*solution)(row_index, pool.col_start());
    auto* segments_ptr = &(*segments)(row_index, pool.col_start());
    for (int i = pool.col_start(); i < pool.col_limit(); ++i) {
      *solution_ptr++ = pool_minimum;
      *segments_ptr++ = segment_id;
    }
    ++segment_id;
  }
}
144
145// Solve a batch of problems using the pool-adjacent violators algorithm.
146// The problems are solved in parallel using tensorflow's thread pool.
147template <typename SegmentType, typename FloatTensor, typename IntTensor>
148void solve_pava_batch(const std::function<SegmentType(int, int)>& make_segment,
149 FloatTensor* solution, IntTensor* segments,
150 tensorflow::OpKernelContext* context) {
151 const int batch_size = solution->dimensions()[0];
152 const int problem_size = solution->dimensions()[1];
153
154 auto thread_pool =
155 context->device()->tensorflow_cpu_worker_threads()->workers;
156
157 thread_pool->ParallelFor(
158 batch_size, kCostMultiplier * problem_size,
159 [&make_segment, &solution, &segments](int64_t row_start,
160 int64_t row_limit) {
161 // Casting to int is safe, as we do boundary checks in `Compute`.
162 for (int row_index = static_cast<int>(row_start);
163 row_index < static_cast<int>(row_limit); ++row_index) {
164 solve_pava(make_segment, solution, segments, row_index);
165 }
166 });
167}
168
169} // namespace
170
171template <typename Tin, typename Tout>
172class IsotonicRegressionOp : public tensorflow::OpKernel {
173 public:
174 explicit IsotonicRegressionOp(tensorflow::OpKernelConstruction* context)
175 : tensorflow::OpKernel(context) {}
176
177 void Compute(tensorflow::OpKernelContext* context) override {
178 // Grab the input tensor.
179 const tensorflow::Tensor& input_tensor = context->input(0);
180 const auto input = input_tensor.flat_inner_dims<Tin, 2>();
181 int int_max = std::numeric_limits<int32>::max();
182 OP_REQUIRES(context,
183 tensorflow::FastBoundsCheck(input.dimensions()[0], int_max) &&
184 tensorflow::FastBoundsCheck(input.dimensions()[1], int_max),
185 tensorflow::errors::InvalidArgument("Tensor too large"));
186
187 // Create the output tensor holding the minimizers.
188 const auto shape = input_tensor.shape();
189 tensorflow::Tensor* output_tensor = nullptr;
190 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
191 {0}, 0, shape, &output_tensor));
192 auto output = output_tensor->flat_inner_dims<Tout, 2>();
193
194 // Create the output tensor holidng the segment memberships.
195 tensorflow::Tensor* segments_tensor = nullptr;
196 OP_REQUIRES_OK(context,
197 context->allocate_output(1, shape, &segments_tensor));
198 auto segments = segments_tensor->flat_inner_dims<int>();
199
200 auto make_l2_segment = [&input](int row_index, int col_index) {
201 return L2PavaSegment<Tout>(input(row_index, col_index), col_index);
202 };
203 solve_pava_batch<L2PavaSegment<Tout>>(make_l2_segment, &output, &segments,
204 context);
205 }
206};
207
// Registers IsotonicRegressionOp<Tin, Tout> as the CPU kernel of the
// "IsotonicRegression" op for input dtype `Tin` (attr "T") and output dtype
// `Tout` (attr "output_dtype").
#define REGISTER_CPU_KERNEL(Tin, Tout)                               \
  REGISTER_KERNEL_BUILDER(Name("IsotonicRegression")                 \
                              .Device(tensorflow::DEVICE_CPU)        \
                              .TypeConstraint<Tin>("T")              \
                              .TypeConstraint<Tout>("output_dtype"), \
                          IsotonicRegressionOp<Tin, Tout>);

// Floating-point inputs produce outputs of the same type.
#define REGISTER_CPU_SAME_KERNEL(T) REGISTER_CPU_KERNEL(T, T)
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SAME_KERNEL);

// 8- and 16-bit integer inputs produce 32-bit float outputs.
#define REGISTER_CPU_KERNEL_FLOAT(Tin) REGISTER_CPU_KERNEL(Tin, float)
TF_CALL_int16(REGISTER_CPU_KERNEL_FLOAT);
TF_CALL_int8(REGISTER_CPU_KERNEL_FLOAT);

// 32- and 64-bit integer inputs produce 64-bit float (double) outputs.
#define REGISTER_CPU_KERNEL_DOUBLE(Tin) REGISTER_CPU_KERNEL(Tin, double)
TF_CALL_int64(REGISTER_CPU_KERNEL_DOUBLE);
TF_CALL_int32(REGISTER_CPU_KERNEL_DOUBLE);
228