1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19
20namespace tensorflow {
21
22using shape_inference::DimensionHandle;
23using shape_inference::InferenceContext;
24using shape_inference::ShapeHandle;
25
// Stateful random op producing a tensor of floating-point type `dtype`
// (half, bfloat16, float or double).
// NOTE(review): the op name suggests uniformly distributed samples; the
// value range is defined by the kernel, which is not visible here — confirm.
//
// Inputs:
//   shape: int32/int64 tensor (attr `T`). The output shape is inferred by
//          shape_inference::RandomShape, presumably from the value of this
//          input (input 0).
// Attrs:
//   seed, seed2: integer seeds, both defaulting to 0.
REGISTER_OP("RandomUniform")
    .Input("shape: T")
    .SetIsStateful()
    .Output("output: dtype")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("dtype: {half,bfloat16,float,double}")
    .Attr("T: {int32, int64}")
    .SetShapeFn(shape_inference::RandomShape);
35
36REGISTER_OP("RandomUniformInt")
37 .Input("shape: T")
38 .Input("minval: Tout")
39 .Input("maxval: Tout")
40 .SetIsStateful()
41 .Output("output: Tout")
42 .Attr("seed: int = 0")
43 .Attr("seed2: int = 0")
44 .Attr("Tout: {int32, int64}")
45 .Attr("T: {int32, int64}")
46 .SetShapeFn([](InferenceContext* c) {
47 ShapeHandle unused;
48 Status s = c->WithRank(c->input(1), 0, &unused);
49 if (!s.ok()) {
50 return errors::InvalidArgument(
51 "minval must be a scalar; got a tensor of shape ",
52 c->DebugString(c->input(1)));
53 }
54 s = c->WithRank(c->input(2), 0, &unused);
55 if (!s.ok()) {
56 return errors::InvalidArgument(
57 "maxval must be a scalar; got a tensor of shape ",
58 c->DebugString(c->input(2)));
59 }
60 return shape_inference::RandomShape(c);
61 });
62
// Stateful random op producing a tensor of floating-point type `dtype`
// (half, bfloat16, float or double).
// NOTE(review): the op name suggests standard-normal samples; the
// distribution is implemented by the kernel (not visible here) — confirm.
//
// Inputs:
//   shape: int32/int64 tensor (attr `T`). The output shape is inferred by
//          shape_inference::RandomShape, presumably from the value of this
//          input (input 0).
// Attrs:
//   seed, seed2: integer seeds, both defaulting to 0.
REGISTER_OP("RandomStandardNormal")
    .Input("shape: T")
    .SetIsStateful()
    .Output("output: dtype")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("dtype: {half,bfloat16,float,double}")
    .Attr("T: {int32, int64}")
    .SetShapeFn(shape_inference::RandomShape);
72
73REGISTER_OP("ParameterizedTruncatedNormal")
74 .Input("shape: T")
75 .Input("means: dtype")
76 .Input("stdevs: dtype")
77 .Input("minvals: dtype")
78 .Input("maxvals: dtype")
79 .SetIsStateful()
80 .Output("output: dtype")
81 .Attr("seed: int = 0")
82 .Attr("seed2: int = 0")
83 .Attr("dtype: {half,bfloat16,float,double}")
84 .Attr("T: {int32, int64}")
85 .SetShapeFn([](InferenceContext* c) {
86 ShapeHandle unused;
87 // Parameters must be 0-d or 1-d.
88 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused));
89 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
90 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused));
91 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
92 return shape_inference::RandomShape(c);
93 });
94
// Stateful random op producing a tensor of floating-point type `dtype`
// (half, bfloat16, float or double).
// NOTE(review): the op name suggests a truncated normal distribution; the
// truncation rule is implemented by the kernel (not visible here) — confirm.
//
// Inputs:
//   shape: int32/int64 tensor (attr `T`). The output shape is inferred by
//          shape_inference::RandomShape, presumably from the value of this
//          input (input 0).
// Attrs:
//   seed, seed2: integer seeds, both defaulting to 0.
REGISTER_OP("TruncatedNormal")
    .Input("shape: T")
    .SetIsStateful()
    .Output("output: dtype")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("dtype: {half,bfloat16,float,double}")
    .Attr("T: {int32, int64}")
    .SetShapeFn(shape_inference::RandomShape);
104
// Stateful random op over a tensor `value` of any type `T`. The output has
// the same type as the input, and the same shape, via
// shape_inference::UnchangedShape.
// NOTE(review): which axis is permuted is decided by the kernel, which is not
// visible here — confirm before documenting it.
//
// Attrs:
//   seed, seed2: integer seeds, both defaulting to 0.
REGISTER_OP("RandomShuffle")
    .Input("value: T")
    .SetIsStateful()
    .Output("output: T")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
113
114REGISTER_OP("Multinomial")
115 .SetIsStateful()
116 .Input("logits: T")
117 .Input("num_samples: int32")
118 .Output("output: output_dtype")
119 .Attr("seed: int = 0")
120 .Attr("seed2: int = 0")
121 .Attr("T: realnumbertype")
122 .Attr("output_dtype: {int32, int64} = DT_INT64")
123 .SetShapeFn([](InferenceContext* c) {
124 ShapeHandle logits_shape;
125 ShapeHandle unused;
126 DimensionHandle num_samples;
127 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape));
128 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
129 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples));
130 c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples));
131 return OkStatus();
132 });
133
134REGISTER_OP("RandomGamma")
135 .SetIsStateful()
136 .Input("shape: S")
137 .Input("alpha: T")
138 .Output("output: T")
139 .Attr("seed: int = 0")
140 .Attr("seed2: int = 0")
141 .Attr("S: {int32, int64}")
142 .Attr("T: {half, float, double}")
143 .SetShapeFn([](InferenceContext* c) {
144 ShapeHandle out;
145 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
146 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
147 c->set_output(0, out);
148 return OkStatus();
149 });
150
// Non-stateful helper op taking `alpha` and `sample` tensors of the same type
// `T` (float or double). The output shape is the broadcast of the two input
// shapes (shape_inference::BroadcastBinaryOpShapeFn).
// NOTE(review): presumably the gradient of RandomGamma samples with respect
// to alpha, judging by the name; confirm against the kernel/gradient code.
// NOTE(review): unlike RandomGamma, `T` here does not include `half`.
REGISTER_OP("RandomGammaGrad")
    .Input("alpha: T")
    .Input("sample: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
157
158REGISTER_OP("RandomPoisson")
159 .SetIsStateful()
160 .Input("shape: S")
161 .Input("rate: dtype")
162 .Output("output: dtype")
163 .Attr("seed: int = 0")
164 .Attr("seed2: int = 0")
165 .Attr("S: {int32, int64}")
166 .Attr("dtype: {half, float, double}")
167 .SetShapeFn([](InferenceContext* c) {
168 ShapeHandle out;
169 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
170 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
171 c->set_output(0, out);
172 return OkStatus();
173 })
174 .Deprecated(25, "Replaced by RandomPoissonV2");
175
176REGISTER_OP("RandomPoissonV2")
177 .SetIsStateful()
178 .Input("shape: S")
179 .Input("rate: R")
180 .Output("output: dtype")
181 .Attr("seed: int = 0")
182 .Attr("seed2: int = 0")
183 .Attr("S: {int32, int64}")
184 .Attr("R: {half, float, double, int32, int64} = DT_DOUBLE")
185 .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
186 .SetShapeFn([](InferenceContext* c) {
187 ShapeHandle out;
188 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
189 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
190 c->set_output(0, out);
191 return OkStatus();
192 });
193
194} // namespace tensorflow
195