1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18
19namespace tensorflow {
20
21using shape_inference::DimensionHandle;
22using shape_inference::InferenceContext;
23using shape_inference::ShapeHandle;
24
25static Status StatelessShape(InferenceContext* c) {
26 // Check seed shape
27 ShapeHandle seed;
28 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
29 DimensionHandle unused;
30 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
31
32 // Set output shape
33 ShapeHandle out;
34 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
35 c->set_output(0, out);
36 return OkStatus();
37}
38
// Registers a stateless random op with:
//   inputs:  shape: T (T in {int32, int64}, default DT_INT32)
//            seed:  Tseed (Tseed in {int32, int64}, default DT_INT64)
//   output:  output: dtype (dtype in {half, bfloat16, float, double},
//            default DT_FLOAT)
// Shape inference is delegated to StatelessShape.
// No comments are placed inside the macro body, because a `//` comment on a
// line ending in a backslash continuation would absorb the following line.
#define REGISTER_STATELESS_OP(name) \
  REGISTER_OP(name) \
      .Input("shape: T") \
      .Input("seed: Tseed") \
      .Output("output: dtype") \
      .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
      .Attr("T: {int32, int64} = DT_INT32") \
      .Attr("Tseed: {int32, int64} = DT_INT64") \
      .SetShapeFn(StatelessShape)

// These three ops share an identical signature, so they are registered via
// the macro above.
REGISTER_STATELESS_OP("StatelessRandomUniform");
REGISTER_STATELESS_OP("StatelessRandomNormal");
REGISTER_STATELESS_OP("StatelessTruncatedNormal");

// The macro is local to the registrations above; drop it so it does not leak.
#undef REGISTER_STATELESS_OP
54
55REGISTER_OP("StatelessRandomUniformInt")
56 .Input("shape: T")
57 .Input("seed: Tseed")
58 .Input("minval: dtype")
59 .Input("maxval: dtype")
60 .Output("output: dtype")
61 .Attr("dtype: {int32, int64}")
62 .Attr("T: {int32, int64}")
63 .Attr("Tseed: {int32, int64} = DT_INT64")
64 .SetShapeFn([](InferenceContext* c) {
65 ShapeHandle unused;
66 Status s = c->WithRank(c->input(2), 0, &unused);
67 if (!s.ok()) {
68 return errors::InvalidArgument(
69 "minval must be a scalar; got a tensor of shape ",
70 c->DebugString(c->input(2)));
71 }
72 s = c->WithRank(c->input(3), 0, &unused);
73 if (!s.ok()) {
74 return errors::InvalidArgument(
75 "maxval must be a scalar; got a tensor of shape ",
76 c->DebugString(c->input(3)));
77 }
78 return StatelessShape(c);
79 });
80
// StatelessRandomUniformFullInt uses the shared (shape, seed) shape function,
// but its output is integral (default DT_UINT64). Unlike the other stateless
// ops registered here, Tseed also admits unsigned types (uint32/uint64).
REGISTER_OP("StatelessRandomUniformFullInt")
    .Input("shape: T")
    .Input("seed: Tseed")
    .Output("output: dtype")
    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
    .Attr("T: {int32, int64} = DT_INT32")
    .Attr("Tseed: {int32, int64, uint32, uint64} = DT_INT64")
    .SetShapeFn(StatelessShape);
89
90REGISTER_OP("StatelessMultinomial")
91 .Input("logits: T")
92 .Input("num_samples: int32")
93 .Input("seed: Tseed")
94 .Output("output: output_dtype")
95 .Attr("T: realnumbertype")
96 .Attr("Tseed: {int32, int64} = DT_INT64")
97 .Attr("output_dtype: {int32, int64} = DT_INT64")
98 .SetShapeFn([](shape_inference::InferenceContext* c) {
99 // Check seed shape
100 ShapeHandle seed;
101 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &seed));
102 DimensionHandle unused_dim;
103 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim));
104
105 ShapeHandle logits_shape;
106 ShapeHandle unused;
107 DimensionHandle num_samples;
108 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape));
109 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
110 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples));
111 c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples));
112 return OkStatus();
113 });
114
// StatelessRandomBinomial takes `counts` and `probs` (both of type T) after the
// usual (shape, seed) pair. The shape function is the shared StatelessShape,
// so only `shape` and `seed` are inspected during shape inference.
// NOTE(review): counts/probs shapes are not validated here (e.g. for broadcast
// compatibility with `shape`); confirm the kernel performs that check.
REGISTER_OP("StatelessRandomBinomial")
    .Input("shape: S")
    .Input("seed: Tseed")
    .Input("counts: T")
    .Input("probs: T")
    .Output("output: dtype")
    .Attr("S: {int32, int64}")
    .Attr("Tseed: {int32, int64} = DT_INT64")
    .Attr("T: {half, float, double, int32, int64} = DT_DOUBLE")
    .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
    .SetShapeFn(StatelessShape);
126
127REGISTER_OP("StatelessParameterizedTruncatedNormal")
128 .Input("shape: S")
129 .Input("seed: Tseed")
130 .Input("means: dtype")
131 .Input("stddevs: dtype")
132 .Input("minvals: dtype")
133 .Input("maxvals: dtype")
134 .Output("output: dtype")
135 .Attr("S: {int32, int64}")
136 .Attr("Tseed: {int32, int64} = DT_INT64")
137 .Attr("dtype: {float16, float32, float64}")
138 .SetShapeFn([](InferenceContext* c) {
139 // Check seed shape
140 ShapeHandle seed;
141 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
142 DimensionHandle unused_dim;
143 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused_dim));
144
145 ShapeHandle bcast_means_stddevs;
146 ShapeHandle bcast_except_maxvals;
147 ShapeHandle bcast_all;
148 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
149 c, c->input(2), c->input(3), true, &bcast_means_stddevs));
150 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
151 c, c->input(4), bcast_means_stddevs, true, &bcast_except_maxvals));
152 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
153 c, c->input(5), bcast_except_maxvals, true, &bcast_all));
154
155 // Set output shape
156 ShapeHandle out;
157 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
158 c->set_output(0, out);
159 return OkStatus();
160 });
161
// StatelessRandomPoisson takes a `lam` input whose type (Rtype) is independent
// of the output dtype. Shape inference is the shared StatelessShape, which
// inspects only `shape` and `seed`; the shape of `lam` is not checked here.
REGISTER_OP("StatelessRandomPoisson")
    .Input("shape: T")
    .Input("seed: Tseed")
    .Input("lam: Rtype")
    .Output("output: dtype")
    .Attr("Rtype: {float16, float32, float64, int32, int64}")
    .Attr("dtype: {float16, float32, float64, int32, int64}")
    .Attr("T: {int32, int64}")
    .Attr("Tseed: {int32, int64} = DT_INT64")
    .SetShapeFn(StatelessShape);
172
// StatelessRandomGammaV2 takes an `alpha` input of the same type as the output.
// Shape inference is the shared StatelessShape, which inspects only `shape` and
// `seed`; the shape of `alpha` is not checked here.
REGISTER_OP("StatelessRandomGammaV2")
    .Input("shape: T")
    .Input("seed: Tseed")
    .Input("alpha: dtype")
    .Output("output: dtype")
    .Attr("dtype: {float16, float32, float64}")
    .Attr("T: {int32, int64}")
    .Attr("Tseed: {int32, int64} = DT_INT64")
    .SetShapeFn(StatelessShape);
182
183} // namespace tensorflow
184