/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {

namespace functor {

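// CPU specialization of UpdateVariableAndFill_Philox. It reads the Philox
// state from the variable tensor, advances the state far enough to cover
// `output_size` samples, releases the variable's lock, and only then fills
// the output, so other users of the variable are not blocked while a
// (possibly large) output is being generated.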
template <typename Distribution>
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const CPUDevice& device,
                  Distribution dist, UpdateVariableAndFill_Philox_Arg* arg,
                  typename Distribution::ResultElementType* output_data)
      TF_UNLOCK_FUNCTION() {
    int64_t output_size = arg->output_size;
    int64_t alg_tag_skip = arg->alg_tag_skip;
    ScopedUnlockUnrefVar* state_var_guard = arg->state_var_guard;
    Tensor* state_tensor = arg->state_tensor;

    auto state_tensor_flat = state_tensor->flat<StateElementType>();
    auto state_data = state_tensor_flat.data();
    // Delegates to PhiloxRandom to do the actual increasing.
    auto philox = GetPhiloxRandomFromMem(state_data + alg_tag_skip);
    UpdateMemWithPhiloxRandom(philox, output_size, state_data + alg_tag_skip);
    // No longer needs the lock.
    state_var_guard->Release();
    functor::FillPhiloxRandom<CPUDevice, Distribution>()(
        ctx, device, /*key=*/nullptr, /*counter=*/nullptr, philox, output_data,
        output_size, dist);
  }
};

}  // end namespace functor

Status CheckState(const Tensor& state) {
  if (state.dtype() != STATE_ELEMENT_DTYPE) {
    return errors::InvalidArgument("dtype of RNG state variable must be ",
                                   DataTypeString(STATE_ELEMENT_DTYPE),
                                   ", not ", DataTypeString(state.dtype()));
  }
  if (state.dims() != 1) {
    return errors::InvalidArgument(
        "RNG state must have one and only one dimension, not ", state.dims());
  }
  return OkStatus();
}

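// For the Philox algorithm, the state tensor holds the counter followed by
// the key (see GetPhiloxRandomFromMem above and the layout note in RngSkipOp
// below), preceded by an algorithm tag when `alg_tag_skip` is nonzero.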
Status CheckPhiloxState(const Tensor& state, int64_t alg_tag_skip = 0) {
  static_assert(std::is_same<StateElementType, int64_t>::value,
                "StateElementType must be int64");
  static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
                "PhiloxRandom::ResultElementType must be uint32");
  auto min_size = alg_tag_skip + PHILOX_MIN_STATE_SIZE;
  if (state.NumElements() < min_size) {
    return errors::InvalidArgument(
        "For the Philox algorithm, the size of state"
        " must be at least ",
        min_size, "; got ", state.NumElements());
  }
  return OkStatus();
}

// `GetScalar` is defined further below; it is declared here because
// `GetAlgId` uses it.
template <typename T>
Status GetScalar(const Tensor& tensor, int input_idx, T* result);

template <typename AlgEnumType>
StatusOr<AlgEnumType> GetAlgId(OpKernelContext* ctx, int input_idx) {
  AlgEnumType alg_id;
  TF_RETURN_IF_ERROR(GetScalar(ctx->input(input_idx), input_idx, &alg_id));
  return alg_id;
}

template <typename AlgEnumType>
StatusOr<ConcreteRngAlgorithm> ResolveAlg(AlgEnumType alg_id) {
  switch (alg_id) {
    case RNG_ALG_PHILOX:
      return ConcreteRngAlgorithm::RNG_ALG_PHILOX;
    case RNG_ALG_THREEFRY:
      return ConcreteRngAlgorithm::RNG_ALG_THREEFRY;
    case RNG_ALG_AUTO_SELECT:
      // On non-XLA kernels, we pick Philox as the auto-selected algorithm.
      return ConcreteRngAlgorithm::RNG_ALG_PHILOX;
    default:
      return errors::InvalidArgument("Unsupported algorithm id: ", alg_id);
  }
}

template <typename AlgEnumType>
StatusOr<ConcreteRngAlgorithm> GetAlg(OpKernelContext* ctx, int input_idx) {
  TF_ASSIGN_OR_RETURN(auto alg_id, GetAlgId<AlgEnumType>(ctx, input_idx));
  return ResolveAlg(alg_id);
}

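// Locks the RNG state variable given by input `state_input_idx`, advances its
// state to account for `output_size` samples, and fills `output_data` with
// samples from `dist`. When `read_alg_from_state` is true, element 0 of the
// state tensor is interpreted as the algorithm id and the `alg` argument is
// ignored.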
template <typename Device, typename Distribution>
Status UpdateVariableAndFill(
    OpKernelContext* ctx, Distribution dist, int state_input_idx,
    bool read_alg_from_state, ConcreteRngAlgorithm alg, int64_t output_size,
    typename Distribution::ResultElementType* output_data) {
  Var* var = nullptr;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
  // Use `ScopedUnlockUnrefVar` here instead of `mutex_lock` and `ScopedUnref`
  // because the former supports early releasing which is needed by
  // `UpdateVariableAndFill_Philox<CPU>` to avoid holding the lock while
  // filling.
  ScopedUnlockUnrefVar state_var_guard(var);
  Tensor* var_tensor = var->tensor();
  TF_RETURN_IF_ERROR(CheckState(*var_tensor));
  auto var_tensor_flat = var_tensor->flat<StateElementType>();
  int64_t alg_tag_skip = 0;
  if (read_alg_from_state) {
    alg_tag_skip = 1;
    if (var_tensor_flat.size() < 1) {
      return errors::InvalidArgument("Size of tensor must be at least 1");
    }
    auto alg_id = var_tensor_flat(0);
    TF_ASSIGN_OR_RETURN(alg, ResolveAlg(alg_id));
  }
  switch (alg) {
    case ConcreteRngAlgorithm::RNG_ALG_PHILOX:
      TF_RETURN_IF_ERROR(CheckPhiloxState(*var_tensor, alg_tag_skip));
      TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, StateElementType>(
          ctx, var_tensor, var->copy_on_read_mode.load()));

      UpdateVariableAndFill_Philox_Arg arg;
      arg.output_size = output_size;
      arg.alg_tag_skip = alg_tag_skip;
      arg.state_var_guard = &state_var_guard;
      arg.state_tensor = var_tensor;
      functor::UpdateVariableAndFill_Philox<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), dist, &arg, output_data);
      return OkStatus();
    case ConcreteRngAlgorithm::RNG_ALG_THREEFRY:
      return errors::Unimplemented(
          "Non-XLA devices don't support the ThreeFry algorithm.");
  }
  return errors::Internal(
      "This point shouldn't have been reached because the above switch should "
      "have handled all algorithms.");
}

// Precondition: input(0) is an existing resource.
template <typename Device, class Distribution>
void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist,
                           int state_input_idx, int shape_input_idx,
                           bool read_alg_from_state, ConcreteRngAlgorithm alg) {
  using T = typename Distribution::ResultElementType;
  const Tensor& shape_t = ctx->input(shape_input_idx);
  TensorShape shape;
  OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
  Tensor* output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
  auto output_flat = output->flat<T>();
  OP_REQUIRES_OK(ctx, UpdateVariableAndFill<Device>(
                          ctx, dist, state_input_idx, read_alg_from_state, alg,
                          output_flat.size(), output_flat.data()));
}

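// Kernel for the deprecated 'StatefulStandardNormal' op (kept for backward
// compatibility; see the registration below). The algorithm id is read from
// element 0 of the state variable (input 0) and the output shape from
// input 1.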
template <typename Device, class Distribution>
class StatefulRandomOp : public OpKernel {
 public:
  explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StatefulRandomCompute<Device>(
        ctx, Distribution(), /*state_input_idx=*/0, /*shape_input_idx=*/1,
        /*read_alg_from_state=*/true,
        ConcreteRngAlgorithm::RNG_ALG_PHILOX /*dummy*/);
  }
};

template <typename T>
Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
  auto dtype = DataTypeToEnum<T>::v();
  if (tensor.dims() != 0) {
    return errors::InvalidArgument("input ", std::to_string(input_idx),
                                   " (0-based) must have shape [], not ",
                                   tensor.shape().DebugString());
  }
  if (tensor.dtype() != dtype) {
    return errors::InvalidArgument("dtype of input ", std::to_string(input_idx),
                                   " (0-based) must be ", DataTypeString(dtype),
                                   ", not ", DataTypeString(tensor.dtype()));
  }
  *result = tensor.flat<T>()(0);
  return OkStatus();
}

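// Kernel for the V2 stateful random ops (e.g. 'StatefulStandardNormalV2' and
// 'StatefulUniform'; see the registrations below). Unlike StatefulRandomOp,
// the algorithm id is an explicit scalar input instead of being stored in the
// state variable. Inputs: 0 = resource, 1 = algorithm, 2 = shape.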
template <typename Device, class Distribution>
class StatefulRandomOpV2 : public OpKernel {
 public:
  explicit StatefulRandomOpV2(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<int64_t>(ctx, 1));
    StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
                                  /*shape_input_idx=*/2,
                                  /*read_alg_from_state=*/false, alg);
  }
};

template <typename Device, class IntType>
class StatefulUniformIntOp : public OpKernel {
 public:
  explicit StatefulUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<int64_t>(ctx, 1));
    const Tensor& minval = ctx->input(3);
    const Tensor& maxval = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval. This kernel has no early exit for empty
    // outputs, so the check always runs.
    IntType lo = minval.scalar<IntType>()();
    IntType hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build the uniform integer distribution over [lo, hi).
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    StatefulRandomCompute<Device>(ctx, dist, /*state_input_idx=*/0,
                                  /*shape_input_idx=*/2,
                                  /*read_alg_from_state=*/false, alg);
  }
};

template <typename Device, class IntType>
class StatefulUniformFullIntOp : public OpKernel {
 public:
  explicit StatefulUniformFullIntOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<int64_t>(ctx, 1));
    StatefulRandomCompute<Device>(
        ctx,
        random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
        /*state_input_idx=*/0, /*shape_input_idx=*/2,
        /*read_alg_from_state=*/false, alg);
  }
};

namespace functor {

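// CPU specialization of RngSkip_Philox: advances the Philox counter read from
// `in_data` by `delta` and writes the new counter to `out_data`. The two
// pointers may refer to the same memory, as in RngSkipOp below, where the
// state is updated in place.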
template <>
struct RngSkip_Philox<CPUDevice> {
  void operator()(const CPUDevice& device, const StateElementType* in_data,
                  uint64 delta, StateElementType* out_data) {
    // Delegates to PhiloxRandom to do the actual increasing.
    auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
    UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
  }
};

}  // end namespace functor

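// Kernel for the 'RngSkip' and 'RngReadAndSkip' ops (see the registrations
// below). Inputs: 0 = resource (the RNG state variable), 1 = algorithm,
// 2 = delta. When `read_old_value` is true (RngReadAndSkip), the pre-advance
// state is returned as output 0, padded with zeros up to
// RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE elements.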
template <typename Device, typename AlgEnumType = int64_t,
          typename DeltaType = int64_t, bool read_old_value = false>
class RngSkipOp : public OpKernel {
 public:
  explicit RngSkipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    auto state_input_idx = 0;
    auto alg_input_idx = 1;
    auto delta_input_idx = 2;
    // GetAlg will treat RNG_ALG_AUTO_SELECT as RNG_ALG_PHILOX.
    OP_REQUIRES_VALUE(auto alg, ctx, GetAlg<AlgEnumType>(ctx, alg_input_idx));
    DeltaType delta_;
    OP_REQUIRES_OK(
        ctx, GetScalar(ctx->input(delta_input_idx), delta_input_idx, &delta_));
    uint64 delta = static_cast<uint64>(delta_);
    Var* var = nullptr;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
    ScopedUnlockUnrefVar state_var_guard(var);
    Tensor* var_tensor = var->tensor();
    OP_REQUIRES_OK(ctx, CheckState(*var_tensor));
    using T = StateElementType;
    OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, T>(
                            ctx, var_tensor, var->copy_on_read_mode.load()));
    if (read_old_value) {
      Tensor* output;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(0, {RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE},
                                    &output));
      auto output_flat = output->flat<T>();
      if (RNG_MAX_COUNTER_SIZE > GetCounterSize(alg)) {
        functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                             output_flat);
      }
      functor::DenseUpdate<Device, T, ASSIGN>()(
          ctx->eigen_device<Device>(), output_flat,
          const_cast<const Tensor*>(var_tensor)->flat<T>());
    }
    switch (alg) {
      case ConcreteRngAlgorithm::RNG_ALG_PHILOX: {
        OP_REQUIRES_OK(ctx, CheckPhiloxState(*var_tensor));
        // var_tensor layout is counter+key, so var_tensor data is also counter
        // data.
        auto counter_data = var_tensor->flat<T>().data();
        functor::RngSkip_Philox<Device>()(ctx->eigen_device<Device>(),
                                          counter_data, delta, counter_data);
        break;
      }
      case ConcreteRngAlgorithm::RNG_ALG_THREEFRY: {
        OP_REQUIRES(
            ctx, false,
            errors::Unimplemented(
                "Non-XLA devices don't support the ThreeFry algorithm."));
        break;
      }
    }
  }
};

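// Fills the output with non-deterministic integers: each element is produced
// by casting random::New64() to the output dtype, so results are not meant to
// be reproducible.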
template <typename T>
class NonDeterministicIntsOp : public OpKernel {
 public:
  explicit NonDeterministicIntsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_t = ctx->input(0);
    TensorShape shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    switch (dtype_) {
      case DT_INT32:
      case DT_UINT32:
      case DT_INT64:
      case DT_UINT64: {
        auto output_flat = output->flat<T>();
        auto data = output_flat.data();
        for (int64_t i = 0; i < output_flat.size(); ++i) {
          data[i] = static_cast<T>(random::New64());
        }
        break;
      }
      default:
        OP_REQUIRES(ctx, false,
                    errors::InvalidArgument("Unsupported dtype: ",
                                            DataTypeString(dtype_)));
    }
  }

 private:
  DataType dtype_;
};

// So far the 'Distribution' type parameter is only used when the algorithm is
// philox, so 'NormalDistribution<PhiloxRandom, ...>' is fine for now.
#define REGISTER_FloatOps(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER( \
      Name("StatefulStandardNormalV2") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("resource") \
          .HostMemory("algorithm") \
          .HostMemory("shape") \
          .TypeConstraint<TYPE>("dtype"), \
      StatefulRandomOpV2<DEVICE##Device, \
                         random::NormalDistribution<PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatefulUniform") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("resource") \
          .HostMemory("algorithm") \
          .HostMemory("shape") \
          .TypeConstraint<TYPE>("dtype"), \
      StatefulRandomOpV2<DEVICE##Device, \
                         random::UniformDistribution<PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER( \
      Name("StatefulTruncatedNormal") \
          .Device(DEVICE_##DEVICE) \
          .HostMemory("resource") \
          .HostMemory("algorithm") \
          .HostMemory("shape") \
          .TypeConstraint<TYPE>("dtype"), \
      StatefulRandomOpV2< \
          DEVICE##Device, \
          random::TruncatedNormalDistribution< \
              random::SingleSampleAdapter<PhiloxRandom>, TYPE> >);

// CPU also has the deprecated 'StatefulStandardNormal' op for backward
// compatibility.
#define REGISTER_FloatOps_CPU(TYPE) \
  REGISTER_FloatOps(CPU, TYPE) REGISTER_KERNEL_BUILDER( \
      Name("StatefulStandardNormal") \
          .Device(DEVICE_CPU) \
          .HostMemory("resource") \
          .HostMemory("shape") \
          .TypeConstraint<TYPE>("dtype"), \
      StatefulRandomOp<CPUDevice, \
                       random::NormalDistribution<PhiloxRandom, TYPE> >);

#define REGISTER_FloatOps_GPU(TYPE) REGISTER_FloatOps(GPU, TYPE)

TF_CALL_half(REGISTER_FloatOps_CPU);
TF_CALL_bfloat16(REGISTER_FloatOps_CPU);
TF_CALL_float(REGISTER_FloatOps_CPU);
TF_CALL_double(REGISTER_FloatOps_CPU);

#define REGISTER_StatefulUniformInt(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER(Name("StatefulUniformInt") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("resource") \
                              .HostMemory("algorithm") \
                              .HostMemory("shape") \
                              .HostMemory("minval") \
                              .HostMemory("maxval") \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatefulUniformIntOp<DEVICE##Device, TYPE>);

#define REGISTER_StatefulUniformInt_CPU(TYPE) \
  REGISTER_StatefulUniformInt(CPU, TYPE)
#define REGISTER_StatefulUniformInt_GPU(TYPE) \
  REGISTER_StatefulUniformInt(GPU, TYPE)

TF_CALL_int32(REGISTER_StatefulUniformInt_CPU);
TF_CALL_int64(REGISTER_StatefulUniformInt_CPU);

#define REGISTER_StatefulUniformFullInt(DEVICE, TYPE) \
  REGISTER_KERNEL_BUILDER(Name("StatefulUniformFullInt") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("resource") \
                              .HostMemory("algorithm") \
                              .HostMemory("shape") \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatefulUniformFullIntOp<DEVICE##Device, TYPE>);

#define REGISTER_StatefulUniformFullInt_CPU(TYPE) \
  REGISTER_StatefulUniformFullInt(CPU, TYPE)
#define REGISTER_StatefulUniformFullInt_GPU(TYPE) \
  REGISTER_StatefulUniformFullInt(GPU, TYPE)

TF_CALL_int32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);

// TODO(wangpeng): Remove `HostMemory("delta")` for RngReadAndSkip
#define REGISTER_RngSkip(DEVICE) \
  REGISTER_KERNEL_BUILDER(Name("RngSkip") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("resource") \
                              .HostMemory("algorithm") \
                              .HostMemory("delta"), \
                          RngSkipOp<DEVICE##Device>); \
  REGISTER_KERNEL_BUILDER(Name("RngReadAndSkip") \
                              .Device(DEVICE_##DEVICE) \
                              .HostMemory("resource") \
                              .HostMemory("alg") \
                              .HostMemory("delta"), \
                          RngSkipOp<DEVICE##Device, int32, uint64, true>);

REGISTER_RngSkip(CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_FloatOps_GPU);
TF_CALL_float(REGISTER_FloatOps_GPU);
TF_CALL_double(REGISTER_FloatOps_GPU);
TF_CALL_int32(REGISTER_StatefulUniformInt_GPU);
TF_CALL_int64(REGISTER_StatefulUniformInt_GPU);
TF_CALL_int32(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_int64(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_GPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_GPU);
REGISTER_RngSkip(GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_StatefulUniformFullInt_GPU
#undef REGISTER_StatefulUniformFullInt_CPU
#undef REGISTER_StatefulUniformFullInt
#undef REGISTER_StatefulUniformInt_GPU
#undef REGISTER_StatefulUniformInt_CPU
#undef REGISTER_StatefulUniformInt
#undef REGISTER_FloatOps_GPU
#undef REGISTER_FloatOps_CPU
#undef REGISTER_FloatOps

#define REGISTER_NonDeterministicInts(TYPE) \
  REGISTER_KERNEL_BUILDER(Name("NonDeterministicInts") \
                              .Device(DEVICE_CPU) \
                              .HostMemory("shape") \
                              .TypeConstraint<TYPE>("dtype"), \
                          NonDeterministicIntsOp<TYPE>);

TF_CALL_int32(REGISTER_NonDeterministicInts);
TF_CALL_uint32(REGISTER_NonDeterministicInts);
TF_CALL_int64(REGISTER_NonDeterministicInts);
TF_CALL_uint64(REGISTER_NonDeterministicInts);

#undef REGISTER_NonDeterministicInts

// TODO(wangpeng): Add RNG ops for other distributions.

}  // end namespace tensorflow