1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::DimensionHandle; |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | static Status StatelessShape(InferenceContext* c) { |
26 | // Check seed shape |
27 | ShapeHandle seed; |
28 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); |
29 | DimensionHandle unused; |
30 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); |
31 | |
32 | // Set output shape |
33 | ShapeHandle out; |
34 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
35 | c->set_output(0, out); |
36 | return OkStatus(); |
37 | } |
38 | |
39 | #define REGISTER_STATELESS_OP(name) \ |
40 | REGISTER_OP(name) \ |
41 | .Input("shape: T") \ |
42 | .Input("seed: Tseed") \ |
43 | .Output("output: dtype") \ |
44 | .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \ |
45 | .Attr("T: {int32, int64} = DT_INT32") \ |
46 | .Attr("Tseed: {int32, int64} = DT_INT64") \ |
47 | .SetShapeFn(StatelessShape) |
48 | |
49 | REGISTER_STATELESS_OP("StatelessRandomUniform" ); |
50 | REGISTER_STATELESS_OP("StatelessRandomNormal" ); |
51 | REGISTER_STATELESS_OP("StatelessTruncatedNormal" ); |
52 | |
53 | #undef REGISTER_STATELESS_OP |
54 | |
55 | REGISTER_OP("StatelessRandomUniformInt" ) |
56 | .Input("shape: T" ) |
57 | .Input("seed: Tseed" ) |
58 | .Input("minval: dtype" ) |
59 | .Input("maxval: dtype" ) |
60 | .Output("output: dtype" ) |
61 | .Attr("dtype: {int32, int64}" ) |
62 | .Attr("T: {int32, int64}" ) |
63 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
64 | .SetShapeFn([](InferenceContext* c) { |
65 | ShapeHandle unused; |
66 | Status s = c->WithRank(c->input(2), 0, &unused); |
67 | if (!s.ok()) { |
68 | return errors::InvalidArgument( |
69 | "minval must be a scalar; got a tensor of shape " , |
70 | c->DebugString(c->input(2))); |
71 | } |
72 | s = c->WithRank(c->input(3), 0, &unused); |
73 | if (!s.ok()) { |
74 | return errors::InvalidArgument( |
75 | "maxval must be a scalar; got a tensor of shape " , |
76 | c->DebugString(c->input(3))); |
77 | } |
78 | return StatelessShape(c); |
79 | }); |
80 | |
81 | REGISTER_OP("StatelessRandomUniformFullInt" ) |
82 | .Input("shape: T" ) |
83 | .Input("seed: Tseed" ) |
84 | .Output("output: dtype" ) |
85 | .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64" ) |
86 | .Attr("T: {int32, int64} = DT_INT32" ) |
87 | .Attr("Tseed: {int32, int64, uint32, uint64} = DT_INT64" ) |
88 | .SetShapeFn(StatelessShape); |
89 | |
90 | REGISTER_OP("StatelessMultinomial" ) |
91 | .Input("logits: T" ) |
92 | .Input("num_samples: int32" ) |
93 | .Input("seed: Tseed" ) |
94 | .Output("output: output_dtype" ) |
95 | .Attr("T: realnumbertype" ) |
96 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
97 | .Attr("output_dtype: {int32, int64} = DT_INT64" ) |
98 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
99 | // Check seed shape |
100 | ShapeHandle seed; |
101 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &seed)); |
102 | DimensionHandle unused_dim; |
103 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim)); |
104 | |
105 | ShapeHandle logits_shape; |
106 | ShapeHandle unused; |
107 | DimensionHandle num_samples; |
108 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape)); |
109 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
110 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples)); |
111 | c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples)); |
112 | return OkStatus(); |
113 | }); |
114 | |
115 | REGISTER_OP("StatelessRandomBinomial" ) |
116 | .Input("shape: S" ) |
117 | .Input("seed: Tseed" ) |
118 | .Input("counts: T" ) |
119 | .Input("probs: T" ) |
120 | .Output("output: dtype" ) |
121 | .Attr("S: {int32, int64}" ) |
122 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
123 | .Attr("T: {half, float, double, int32, int64} = DT_DOUBLE" ) |
124 | .Attr("dtype: {half, float, double, int32, int64} = DT_INT64" ) |
125 | .SetShapeFn(StatelessShape); |
126 | |
127 | REGISTER_OP("StatelessParameterizedTruncatedNormal" ) |
128 | .Input("shape: S" ) |
129 | .Input("seed: Tseed" ) |
130 | .Input("means: dtype" ) |
131 | .Input("stddevs: dtype" ) |
132 | .Input("minvals: dtype" ) |
133 | .Input("maxvals: dtype" ) |
134 | .Output("output: dtype" ) |
135 | .Attr("S: {int32, int64}" ) |
136 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
137 | .Attr("dtype: {float16, float32, float64}" ) |
138 | .SetShapeFn([](InferenceContext* c) { |
139 | // Check seed shape |
140 | ShapeHandle seed; |
141 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); |
142 | DimensionHandle unused_dim; |
143 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim)); |
144 | |
145 | ShapeHandle bcast_means_stddevs; |
146 | ShapeHandle bcast_except_maxvals; |
147 | ShapeHandle bcast_all; |
148 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( |
149 | c, c->input(2), c->input(3), true, &bcast_means_stddevs)); |
150 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( |
151 | c, c->input(4), bcast_means_stddevs, true, &bcast_except_maxvals)); |
152 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( |
153 | c, c->input(5), bcast_except_maxvals, true, &bcast_all)); |
154 | |
155 | // Set output shape |
156 | ShapeHandle out; |
157 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
158 | c->set_output(0, out); |
159 | return OkStatus(); |
160 | }); |
161 | |
162 | REGISTER_OP("StatelessRandomPoisson" ) |
163 | .Input("shape: T" ) |
164 | .Input("seed: Tseed" ) |
165 | .Input("lam: Rtype" ) |
166 | .Output("output: dtype" ) |
167 | .Attr("Rtype: {float16, float32, float64, int32, int64}" ) |
168 | .Attr("dtype: {float16, float32, float64, int32, int64}" ) |
169 | .Attr("T: {int32, int64}" ) |
170 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
171 | .SetShapeFn(StatelessShape); |
172 | |
173 | REGISTER_OP("StatelessRandomGammaV2" ) |
174 | .Input("shape: T" ) |
175 | .Input("seed: Tseed" ) |
176 | .Input("alpha: dtype" ) |
177 | .Output("output: dtype" ) |
178 | .Attr("dtype: {float16, float32, float64}" ) |
179 | .Attr("T: {int32, int64}" ) |
180 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
181 | .SetShapeFn(StatelessShape); |
182 | |
183 | } // namespace tensorflow |
184 | |