1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::DimensionHandle; |
23 | using shape_inference::InferenceContext; |
24 | using shape_inference::ShapeHandle; |
25 | |
26 | REGISTER_OP("RandomUniform" ) |
27 | .Input("shape: T" ) |
28 | .SetIsStateful() |
29 | .Output("output: dtype" ) |
30 | .Attr("seed: int = 0" ) |
31 | .Attr("seed2: int = 0" ) |
32 | .Attr("dtype: {half,bfloat16,float,double}" ) |
33 | .Attr("T: {int32, int64}" ) |
34 | .SetShapeFn(shape_inference::RandomShape); |
35 | |
36 | REGISTER_OP("RandomUniformInt" ) |
37 | .Input("shape: T" ) |
38 | .Input("minval: Tout" ) |
39 | .Input("maxval: Tout" ) |
40 | .SetIsStateful() |
41 | .Output("output: Tout" ) |
42 | .Attr("seed: int = 0" ) |
43 | .Attr("seed2: int = 0" ) |
44 | .Attr("Tout: {int32, int64}" ) |
45 | .Attr("T: {int32, int64}" ) |
46 | .SetShapeFn([](InferenceContext* c) { |
47 | ShapeHandle unused; |
48 | Status s = c->WithRank(c->input(1), 0, &unused); |
49 | if (!s.ok()) { |
50 | return errors::InvalidArgument( |
51 | "minval must be a scalar; got a tensor of shape " , |
52 | c->DebugString(c->input(1))); |
53 | } |
54 | s = c->WithRank(c->input(2), 0, &unused); |
55 | if (!s.ok()) { |
56 | return errors::InvalidArgument( |
57 | "maxval must be a scalar; got a tensor of shape " , |
58 | c->DebugString(c->input(2))); |
59 | } |
60 | return shape_inference::RandomShape(c); |
61 | }); |
62 | |
63 | REGISTER_OP("RandomStandardNormal" ) |
64 | .Input("shape: T" ) |
65 | .SetIsStateful() |
66 | .Output("output: dtype" ) |
67 | .Attr("seed: int = 0" ) |
68 | .Attr("seed2: int = 0" ) |
69 | .Attr("dtype: {half,bfloat16,float,double}" ) |
70 | .Attr("T: {int32, int64}" ) |
71 | .SetShapeFn(shape_inference::RandomShape); |
72 | |
73 | REGISTER_OP("ParameterizedTruncatedNormal" ) |
74 | .Input("shape: T" ) |
75 | .Input("means: dtype" ) |
76 | .Input("stdevs: dtype" ) |
77 | .Input("minvals: dtype" ) |
78 | .Input("maxvals: dtype" ) |
79 | .SetIsStateful() |
80 | .Output("output: dtype" ) |
81 | .Attr("seed: int = 0" ) |
82 | .Attr("seed2: int = 0" ) |
83 | .Attr("dtype: {half,bfloat16,float,double}" ) |
84 | .Attr("T: {int32, int64}" ) |
85 | .SetShapeFn([](InferenceContext* c) { |
86 | ShapeHandle unused; |
87 | // Parameters must be 0-d or 1-d. |
88 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused)); |
89 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); |
90 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused)); |
91 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); |
92 | return shape_inference::RandomShape(c); |
93 | }); |
94 | |
95 | REGISTER_OP("TruncatedNormal" ) |
96 | .Input("shape: T" ) |
97 | .SetIsStateful() |
98 | .Output("output: dtype" ) |
99 | .Attr("seed: int = 0" ) |
100 | .Attr("seed2: int = 0" ) |
101 | .Attr("dtype: {half,bfloat16,float,double}" ) |
102 | .Attr("T: {int32, int64}" ) |
103 | .SetShapeFn(shape_inference::RandomShape); |
104 | |
105 | REGISTER_OP("RandomShuffle" ) |
106 | .Input("value: T" ) |
107 | .SetIsStateful() |
108 | .Output("output: T" ) |
109 | .Attr("seed: int = 0" ) |
110 | .Attr("seed2: int = 0" ) |
111 | .Attr("T: type" ) |
112 | .SetShapeFn(shape_inference::UnchangedShape); |
113 | |
114 | REGISTER_OP("Multinomial" ) |
115 | .SetIsStateful() |
116 | .Input("logits: T" ) |
117 | .Input("num_samples: int32" ) |
118 | .Output("output: output_dtype" ) |
119 | .Attr("seed: int = 0" ) |
120 | .Attr("seed2: int = 0" ) |
121 | .Attr("T: realnumbertype" ) |
122 | .Attr("output_dtype: {int32, int64} = DT_INT64" ) |
123 | .SetShapeFn([](InferenceContext* c) { |
124 | ShapeHandle logits_shape; |
125 | ShapeHandle unused; |
126 | DimensionHandle num_samples; |
127 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape)); |
128 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
129 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples)); |
130 | c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples)); |
131 | return OkStatus(); |
132 | }); |
133 | |
134 | REGISTER_OP("RandomGamma" ) |
135 | .SetIsStateful() |
136 | .Input("shape: S" ) |
137 | .Input("alpha: T" ) |
138 | .Output("output: T" ) |
139 | .Attr("seed: int = 0" ) |
140 | .Attr("seed2: int = 0" ) |
141 | .Attr("S: {int32, int64}" ) |
142 | .Attr("T: {half, float, double}" ) |
143 | .SetShapeFn([](InferenceContext* c) { |
144 | ShapeHandle out; |
145 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
146 | TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); |
147 | c->set_output(0, out); |
148 | return OkStatus(); |
149 | }); |
150 | |
151 | REGISTER_OP("RandomGammaGrad" ) |
152 | .Input("alpha: T" ) |
153 | .Input("sample: T" ) |
154 | .Output("output: T" ) |
155 | .Attr("T: {float, double}" ) |
156 | .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); |
157 | |
158 | REGISTER_OP("RandomPoisson" ) |
159 | .SetIsStateful() |
160 | .Input("shape: S" ) |
161 | .Input("rate: dtype" ) |
162 | .Output("output: dtype" ) |
163 | .Attr("seed: int = 0" ) |
164 | .Attr("seed2: int = 0" ) |
165 | .Attr("S: {int32, int64}" ) |
166 | .Attr("dtype: {half, float, double}" ) |
167 | .SetShapeFn([](InferenceContext* c) { |
168 | ShapeHandle out; |
169 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
170 | TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); |
171 | c->set_output(0, out); |
172 | return OkStatus(); |
173 | }) |
174 | .Deprecated(25, "Replaced by RandomPoissonV2" ); |
175 | |
176 | REGISTER_OP("RandomPoissonV2" ) |
177 | .SetIsStateful() |
178 | .Input("shape: S" ) |
179 | .Input("rate: R" ) |
180 | .Output("output: dtype" ) |
181 | .Attr("seed: int = 0" ) |
182 | .Attr("seed2: int = 0" ) |
183 | .Attr("S: {int32, int64}" ) |
184 | .Attr("R: {half, float, double, int32, int64} = DT_DOUBLE" ) |
185 | .Attr("dtype: {half, float, double, int32, int64} = DT_INT64" ) |
186 | .SetShapeFn([](InferenceContext* c) { |
187 | ShapeHandle out; |
188 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
189 | TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); |
190 | c->set_output(0, out); |
191 | return OkStatus(); |
192 | }); |
193 | |
194 | } // namespace tensorflow |
195 | |