/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <random>
#include <vector>

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions_utils.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace random {

namespace {

using Generator = ::tensorflow::random::PhiloxRandom;

enum RandomType { kRandomUniform, kRandomStandardNormal, kMultinomial };

struct OpData {
  Generator rng;
};

// Initialize the OpData based on the seed and seed2 values.
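// If both 'seed' and 'seed2' are zero, non-deterministic seeds are drawn from
// a process-wide std::mt19937_64 that is itself seeded from std::random_device,
// so each Prepare call with unspecified seeds yields a different sequence.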
void InitializeOpData(TfLiteNode* node) {
  static std::mt19937_64* seed_generator = []() {
    std::random_device device("/dev/urandom");
    return new std::mt19937_64(device());
  }();
  auto* params = static_cast<TfLiteRandomParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  int64_t seed = params->seed;
  int64_t seed2 = params->seed2;
  if (seed == 0 && seed2 == 0) {
    // If both seeds are unspecified, generate non-deterministic random numbers.
    seed = (*seed_generator)();
    seed2 = (*seed_generator)();
  }
  Generator rng(seed, seed2);
  data->rng = rng;
}

// Generates random numbers following a uniform distribution.
// Source: third_party/tensorflow/core/kernels/random_op.cc
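// Each call to the Philox generator yields Generator::kResultElementCount
// uint32 values; Uint32ToFloat maps each one to a float in [0, 1).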
void GenerateRandomUniformNumbers(
    Generator& rng, float* buffer, size_t buffer_size) {
  size_t current_size = 0;
  size_t rng_size = Generator::kResultElementCount;

  while (current_size < buffer_size) {
    typename Generator::ResultType samples = rng();
    const int rng_net_size = std::min(rng_size, buffer_size - current_size);
    for (int i = 0; i < rng_net_size; i++) {
      buffer[current_size + i] = tensorflow::random::Uint32ToFloat(samples[i]);
    }
    current_size += rng_net_size;
  }
}

// Generates random numbers following a standard normal distribution.
// Source: third_party/tensorflow/core/kernels/random_op.cc
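// BoxMullerFloat consumes two uint32 samples and writes two independent
// standard-normal floats, which is why the loop below advances by two.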
void GenerateRandomStandardNormalNumbers(
    Generator& rng, float* buffer, size_t buffer_size) {
  size_t current_size = 0;
  size_t rng_size = Generator::kResultElementCount;

  while (current_size < buffer_size) {
    typename Generator::ResultType samples = rng();
    const int rng_net_size = std::min(rng_size, buffer_size - current_size);
    for (int i = 0; i < rng_net_size; i += 2) {
      if (i + 1 < rng_net_size) {
        tensorflow::random::BoxMullerFloat(samples[i], samples[i + 1],
                                           &buffer[current_size + i],
                                           &buffer[current_size + i + 1]);
      } else {
        // Odd tail: generate a full pair but keep only the first value so the
        // write stays within 'buffer' when 'buffer_size' is odd.
        float discarded;
        tensorflow::random::BoxMullerFloat(samples[i], samples[i + 1],
                                           &buffer[current_size + i],
                                           &discarded);
      }
    }
    current_size += rng_net_size;
  }
}

// Generates random numbers following a multinomial distribution.
// Source: third_party/tensorflow/core/kernels/multinomial_op.cc
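// For each batch row the code below builds an unnormalized CDF of
// exp(logit - max_logit), then for every sample scales a uniform double in
// [0, 1) by the CDF total and uses std::upper_bound to locate the selected
// class index.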
template <typename IntType>
void GenerateMultinomialNumbers(Generator& rng, int batch_size,
                                const float* logits, size_t logits_size,
                                IntType* output, size_t num_samples) {
  // Skip a large, fixed number of samples in the rng (random number generator)
  // on each op invocation to ensure that the output is always unique. (A copy
  // of the rng is made before skipping, and that copy is used for the current
  // invocation.)
  // Context: This feature (skipping a fixed number of samples) was added in TF
  // because some versions of the Multinomial op draw an unknown number of
  // samples from the rng. Though the TFLite version below only draws a fixed
  // number of samples, the skip is kept to maintain parity with the TF op.
  Generator rng_copy = rng;
  rng.Skip(batch_size * ((num_samples + 3) / 4 * 4) * 2 *
           256);  // Round to a multiple of 4, 2x is for CPU and 256 is a
                  // conservative multiplier
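  // Illustrative example of the skip amount: with batch_size = 2 and
  // num_samples = 5, the rounded sample count is 8 and the generator skips
  // 2 * 8 * 2 * 256 = 8192 results.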

  // Variables to store intermediate results between batches.
  typename Generator::ResultType rng_results;
  int used_rng_results_index = Generator::kResultElementCount;
  typename Generator::ResultElementType x0, x1;

  // Iterate over all batches to compute the outputs.
  for (int batch = 0; batch < batch_size; ++batch) {
    const float* logits_row = logits + batch * logits_size;
    IntType* output_row = output + batch * num_samples;

    // Compute the maximum finite logit.
    float max = std::numeric_limits<float>::lowest();
    for (size_t i = 0; i < logits_size; i++) {
      if (std::isfinite(logits_row[i])) {
        max = std::max(max, logits_row[i]);
      }
    }
    const double max_logit = static_cast<double>(max);

    // Compute the (unnormalized) cumulative probability distribution.
    // For numerical stability (as the exponential function grows very fast),
    // subtract the maximum logit. Though you can subtract any value without
    // changing the output, we use the maximum logit for convenience.
    std::vector<double> cdf(logits_size);
    double cumulative_total = 0.0;
    for (size_t i = 0; i < logits_size; i++) {
      if (std::isfinite(logits_row[i])) {
        cumulative_total += std::exp(logits_row[i] - max_logit);
      }
      cdf[i] = cumulative_total;
    }
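    // Illustrative example: logits_row = {ln 1, ln 2, ln 3} gives
    // exp(logit - max_logit) = {1/3, 2/3, 1} and cdf = {1/3, 1, 2}, so a
    // uniform draw scaled to [0, 2) selects classes 0, 1 and 2 with
    // probabilities 1/6, 2/6 and 3/6 respectively.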

    // Generate random categorical numbers and populate the output.
    for (int64_t j = 0; j < num_samples; ++j) {
      if (used_rng_results_index == Generator::kResultElementCount) {
        rng_results = rng_copy();
        used_rng_results_index = 0;
      }
      x0 = rng_results[used_rng_results_index];
      x1 = rng_results[used_rng_results_index + 1];
      used_rng_results_index += 2;
      const double to_find =
          (tensorflow::random::Uint64ToDouble(x0, x1) * cumulative_total);
      auto found_iter = std::upper_bound(cdf.begin(), cdf.end(), to_find);
      output_row[j] = std::distance(cdf.begin(), found_iter);
    }
  }
}

}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData();
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Validate number of inputs and outputs
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // 'shape' is a 1-D int array
  const TfLiteTensor* shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &shape));
  TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);

  // Initialize the random number generator
  InitializeOpData(node);

  TfLiteTensor* output = GetOutput(context, node, 0);
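  // If the 'shape' input is not a constant tensor, the output shape is only
  // known at Eval time, so mark the output dynamic and resize it during Eval.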
  if (!IsConstantTensor(shape)) {
    SetTensorToDynamic(output);
    return kTfLiteOk;
  }
  TfLiteIntArray* output_shape;
  TF_LITE_ENSURE_OK(context,
                    GetOutputShapeFromInput(context, shape, &output_shape));
  return context->ResizeTensor(context, output, output_shape);
}

TfLiteStatus PrepareMultinomial(TfLiteContext* context, TfLiteNode* node) {
  // Validate number of inputs and outputs
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // 'logits' is a 2-D input float matrix with shape [batch_size, num_classes]
  const TfLiteTensor* logits;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &logits));
  TF_LITE_ENSURE(context, logits->type == kTfLiteFloat32);

  // 'num_samples' is a 0-D input int scalar
  const TfLiteTensor* num_samples;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &num_samples));
  TF_LITE_ENSURE_EQ(context, num_samples->type, kTfLiteInt32);

  // Initialize the random number generator
  InitializeOpData(node);

  TfLiteTensor* output = GetOutput(context, node, 0);
  if (!IsConstantTensor(logits) || !IsConstantTensor(num_samples)) {
    SetTensorToDynamic(output);
    return kTfLiteOk;
  }

  // 'output' is a 2-D int64/int32 matrix with shape [batch_size, num_samples]
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(2);
  output_shape->data[0] = SizeOfDimension(logits, 0);  // batch_size
  output_shape->data[1] = *num_samples->data.i32;      // num_samples
  return context->ResizeTensor(context, output, output_shape);
}

TfLiteStatus EvalRandomType(
    TfLiteContext* context, TfLiteNode* node, RandomType random_type) {
  TfLiteTensor* output = GetOutput(context, node, 0);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  const size_t output_size = NumElements(output);
  switch (random_type) {
    case kRandomUniform:
      GenerateRandomUniformNumbers(
          data->rng, GetTensorData<float>(output), output_size);
      break;
    case kRandomStandardNormal:
      GenerateRandomStandardNormalNumbers(
          data->rng, GetTensorData<float>(output), output_size);
      break;
    default:
      return kTfLiteError;
  }
  return kTfLiteOk;
}

template <RandomType rtype>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, 0);

  if (IsDynamicTensor(output)) {
    const TfLiteTensor* shape = GetInput(context, node, 0);
    TfLiteIntArray* output_shape;
    TF_LITE_ENSURE_OK(context,
                      GetOutputShapeFromInput(context, shape, &output_shape));
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output, output_shape));
  }

  switch (output->type) {
    case kTfLiteFloat32:
      EvalRandomType(context, node, rtype);
      break;
    default:
      TF_LITE_KERNEL_LOG(
          context, "Unsupported output datatype for %s op: %s",
          rtype == kRandomUniform ? "RandomUniform" : "RandomStandardNormal",
          TfLiteTypeGetName(output->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus EvalMultinomial(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  // 'logits' is a 2-D float matrix with shape [batch_size, num_classes]
  const TfLiteTensor* logits_tensor = GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, NumDimensions(logits_tensor), 2);
  const float* logits = GetTensorData<float>(logits_tensor);
  const int batch_size = SizeOfDimension(logits_tensor, 0);
  const int num_classes = SizeOfDimension(logits_tensor, 1);
  TF_LITE_ENSURE(context, num_classes > 0);

  // 'num_samples' is an int scalar
  const TfLiteTensor* num_samples_tensor = GetInput(context, node, 1);
  TF_LITE_ENSURE_EQ(context, NumDimensions(num_samples_tensor), 0);
  const int num_samples = *num_samples_tensor->data.i32;
  TF_LITE_ENSURE(context, num_samples >= 0);

  TfLiteTensor* output_tensor = GetOutput(context, node, 0);
  if (IsDynamicTensor(output_tensor)) {
    // 'output' is a 2-D int64/int32 matrix with shape [batch_size, num_samples]
    TfLiteIntArray* output_shape = TfLiteIntArrayCreate(2);
    output_shape->data[0] = batch_size;
    output_shape->data[1] = num_samples;
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, output_tensor, output_shape));
  }

  switch (output_tensor->type) {
    case kTfLiteInt64:
      GenerateMultinomialNumbers<int64_t>(
          data->rng, batch_size, logits, num_classes,
          GetTensorData<int64_t>(output_tensor), num_samples);
      break;
    case kTfLiteInt32:
      GenerateMultinomialNumbers<int32_t>(
          data->rng, batch_size, logits, num_classes,
          GetTensorData<int32_t>(output_tensor), num_samples);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Unsupported output datatype for Multinomial op: %s",
                         TfLiteTypeGetName(output_tensor->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace random

TfLiteRegistration* Register_RANDOM_UNIFORM() {
  static TfLiteRegistration r = {random::Init, random::Free, random::Prepare,
                                 random::Eval<random::kRandomUniform>};
  return &r;
}

TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL() {
  static TfLiteRegistration r = {random::Init, random::Free, random::Prepare,
                                 random::Eval<random::kRandomStandardNormal>};
  return &r;
}

TfLiteRegistration* Register_MULTINOMIAL() {
  static TfLiteRegistration r = {random::Init, random::Free,
                                 random::PrepareMultinomial,
                                 random::EvalMultinomial};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
346