/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/multinomial_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

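// Generic functor declaration; the device-specific specializations do the
// actual sampling.  The GPU specializations are compiled in a separate
// CUDA/ROCm translation unit (hence the extern templates below), while the
// CPU specialization follows in this file.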
template <typename Device, typename T, typename OutputType>
struct MultinomialFunctor {
  void operator()(OpKernelContext* ctx, const Device& d,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<float>::Flat noises,
                  typename TTypes<float>::Flat scores,
                  typename TTypes<float>::Flat scratch, int batch_size,
                  int num_classes, int num_samples,
                  const random::PhiloxRandom& gen,
                  typename TTypes<OutputType>::Matrix output);
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
extern template struct MultinomialFunctor<GPUDevice, Eigen::half, int32>;
extern template struct MultinomialFunctor<GPUDevice, float, int32>;
extern template struct MultinomialFunctor<GPUDevice, double, int32>;
extern template struct MultinomialFunctor<GPUDevice, int32, int32>;
extern template struct MultinomialFunctor<GPUDevice, int64_t, int32>;

extern template struct MultinomialFunctor<GPUDevice, Eigen::half, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, float, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, double, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, int32, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, int64_t, int64_t>;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

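// CPU specialization.  For each batch row this builds an unnormalized CDF of
// exp(logit - max_logit) over the classes and then draws each sample by
// inverting the CDF with a binary search (std::upper_bound).  For example,
// with logits [log(1), log(3)] the unnormalized CDF is [1, 4]; a uniform
// draw u in [0, 4) picks class 0 when u < 1 and class 1 otherwise.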
template <typename T, typename OutputType>
struct MultinomialFunctor<CPUDevice, T, OutputType> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<float>::Flat /* noises */,
                  typename TTypes<float>::Flat /* scores */,
                  typename TTypes<float>::Flat /* scratch */, int batch_size,
                  int num_classes, int num_samples,
                  const random::PhiloxRandom& gen,
                  typename TTypes<OutputType>::Matrix output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    // The implementation only parallelizes by batch.
    //
    // This takes O(BatchSize * NumSamples * log(NumClasses) + NumClasses) CPU
    // time.
    auto DoWork = [ctx, num_samples, num_classes, &gen, &output, &logits](
                      int64_t start_row, int64_t limit_row) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda.  Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      // Skip takes units of 128 bits.  +3 is so rounding doesn't lead to
      // us using the same state in different batches.
      gen_copy.Skip(start_row * (num_samples + 3) / 4);
      random::SimplePhilox simple_philox(&gen_copy);

      Tensor cdf_tensor;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_DOUBLE, TensorShape({num_classes}),
                                        &cdf_tensor));
      auto cdf = cdf_tensor.flat<double>();
      for (int64_t b = start_row; b < limit_row; ++b) {
        const auto* logits_row = &logits(b, 0);

        // Takes an along-class maximum (for numerical stability).
        T max = std::numeric_limits<T>::lowest();
        for (int64_t j = 0; j < num_classes; ++j) {
          if (Eigen::numext::isfinite(logits_row[j])) {
            max = std::max(max, logits_row[j]);
          }
        }
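        // Non-finite logits are excluded from the maximum here; below they
        // also contribute zero cumulative mass, so they can never be sampled.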
        const double max_logit = static_cast<double>(max);

        // Precompute cumulative probability distribution across classes.
        // Note: This isn't normalized.
        cdf = (logits.template chip<0>(b).template cast<double>() - max_logit)
                  .exp();
        double running_total = 0;
        for (int64_t j = 0; j < num_classes; ++j) {
          if (Eigen::numext::isfinite(logits_row[j])) {
            running_total += cdf(j);
          }
          cdf(j) = running_total;
        }
        // Generate each sample.
        const double* cdf_begin = cdf.data();
        const double* cdf_end = cdf.data() + num_classes;
        for (int64_t j = 0; j < num_samples; ++j) {
          const double to_find = simple_philox.RandDouble() * running_total;
          auto found_iter = std::upper_bound(cdf_begin, cdf_end, to_find);
          output(b, j) = std::distance(cdf_begin, found_iter);
        }
      }
    };
    // Incredibly rough estimate of clock cycles for DoWork();
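    // (num_samples binary searches of ~log2(num_classes) steps each, plus the
    // O(num_classes) CDF build, at an assumed ~50 cycles per step).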
    const int64_t cost =
        50 * (num_samples * std::log(num_classes) / std::log(2) + num_classes);
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost,
          DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a multinomial distribution.
template <typename Device, typename T, typename OutputType>
class MultinomialOp : public OpKernel {
 public:
  explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {}

  void DoCompute(OpKernelContext* ctx, const Tensor& logits_t,
                 const Tensor& num_samples_t, GuardedPhiloxRandom* generator) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_t.shape()),
                errors::InvalidArgument("logits should be a matrix, got shape ",
                                        logits_t.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(num_samples_t.shape()),
        errors::InvalidArgument("num_samples should be a scalar, got shape ",
                                num_samples_t.shape().DebugString()));

    const int num_samples = num_samples_t.scalar<int>()();
    OP_REQUIRES(ctx, num_samples >= 0,
                errors::InvalidArgument(
                    "num_samples should be nonnegative, got ", num_samples));

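    // batch_size and num_classes below are plain ints, so reject any logits
    // dimension that doesn't round-trip through a static_cast<int>.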
    for (int i = 0; i < 2; i++) {
      const int64_t dim = logits_t.dim_size(i);
      OP_REQUIRES(ctx, static_cast<int>(dim) == dim,
                  errors::InvalidArgument(
                      "logits.shape = ", logits_t.shape().DebugString(),
                      " too large for int"));
    }
    const int batch_size = static_cast<int>(logits_t.dim_size(0));
    const int num_classes = static_cast<int>(logits_t.dim_size(1));
    OP_REQUIRES(ctx, num_classes > 0,
                errors::InvalidArgument("num_classes should be positive, got ",
                                        num_classes));

    Tensor* samples_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({batch_size, num_samples}),
                                  &samples_t));

    // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU.
    if (samples_t->NumElements() > 0) {
      Tensor noises, scores, scratch;  // Scratch space only used for GPU.
      if (std::is_same<Device, GPUDevice>::value) {
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(
                DT_FLOAT, TensorShape({batch_size, num_samples, num_classes}),
                &noises));
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(
                DT_FLOAT, TensorShape({batch_size, num_samples, num_classes}),
                &scores));
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(DT_FLOAT, TensorShape({batch_size, num_samples}),
                               &scratch));
      }

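      // Round num_samples up to a multiple of four so the reservation stays
      // aligned with PhiloxRandom's 128-bit (4 x 32-bit) output blocks and
      // with the per-row Skip() in the CPU functor above.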
      int num_samples_ceil_4 = (num_samples + 3) / 4 * 4;
      // CPU generates doubles = 2 samples per number.
      if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2;
      auto rng =
          generator->ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
      functor::MultinomialFunctor<Device, T, OutputType>()(
          ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(),
          noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(),
          batch_size, num_classes, num_samples, rng,
          samples_t->matrix<OutputType>());
    }
  }
};

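// DoCompute() above is shared by the stateful and stateless wrappers below;
// the two differ only in how the Philox generator is seeded.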
template <typename Device, typename T, typename OutputType>
class StatefulMultinomialOp : public MultinomialOp<Device, T, OutputType> {
 public:
  explicit StatefulMultinomialOp(OpKernelConstruction* ctx)
      : MultinomialOp<Device, T, OutputType>(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& logits_t = ctx->input(0);
    const Tensor& num_samples_t = ctx->input(1);
    this->DoCompute(ctx, logits_t, num_samples_t, &generator_);
  }

 private:
  GuardedPhiloxRandom generator_;
};

// TODO(b/77906027): Add a TPU implementation.
#define REGISTER(TYPE)                                                    \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                             \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT32),  \
                          StatefulMultinomialOp<CPUDevice, TYPE, int32>); \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                             \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT64),  \
                          StatefulMultinomialOp<CPUDevice, TYPE, int64>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER
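// Note: in the Python API these stateful kernels back tf.compat.v1.multinomial
// (and, as of this writing, tf.random.categorical).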

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(TYPE)                                                   \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("num_samples")                 \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT32), \
                          StatefulMultinomialOp<GPUDevice, TYPE, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("num_samples")                 \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT64), \
                          StatefulMultinomialOp<GPUDevice, TYPE, int64>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, typename OutputType>
class StatelessMultinomialOp : public MultinomialOp<Device, T, OutputType> {
 public:
  explicit StatelessMultinomialOp(OpKernelConstruction* ctx)
      : MultinomialOp<Device, T, OutputType>(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& logits_t = ctx->input(0);
    const Tensor& num_samples_t = ctx->input(1);

    const Tensor& seed_t = ctx->input(2);
    OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));

    GuardedPhiloxRandom generator;
    generator.Init(counter, key);

    this->DoCompute(ctx, logits_t, num_samples_t, &generator);
  }
};
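// Because the key/counter pair is derived entirely from the "seed" input,
// identical (logits, num_samples, seed) inputs reproduce identical samples;
// in the Python API this backs tf.random.stateless_categorical.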

#define REGISTER(TYPE)                                                     \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                     \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<TYPE>("T")                   \
                              .TypeConstraint("output_dtype", DT_INT32),   \
                          StatelessMultinomialOp<CPUDevice, TYPE, int32>); \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                     \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<TYPE>("T")                   \
                              .TypeConstraint("output_dtype", DT_INT64),   \
                          StatelessMultinomialOp<CPUDevice, TYPE, int64>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(TYPE)                                                    \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                    \
                              .Device(DEVICE_GPU)                         \
                              .HostMemory("num_samples")                  \
                              .HostMemory("seed")                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT32),  \
                          StatelessMultinomialOp<GPUDevice, TYPE, int32>) \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                    \
                              .Device(DEVICE_GPU)                         \
                              .HostMemory("num_samples")                  \
                              .HostMemory("seed")                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT64),  \
                          StatelessMultinomialOp<GPUDevice, TYPE, int64>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace

}  // end namespace tensorflow