1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/rng_alg.h"
19
20namespace tensorflow {
21
22using shape_inference::DimensionHandle;
23using shape_inference::InferenceContext;
24using shape_inference::ShapeHandle;
25
26static Status StatelessShapeV2(InferenceContext* c) {
27 // Check key and counter shapes
28 ShapeHandle key;
29 ShapeHandle counter;
30 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &key));
31 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &counter));
32 shape_inference::ShapeHandle unused_shape;
33 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
34 DimensionHandle unused;
35 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), RNG_KEY_SIZE, &unused));
36
37 // Set output shape
38 ShapeHandle out;
39 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
40 c->set_output(0, out);
41 return OkStatus();
42}
43
44#define REGISTER_STATELESS_OP(name) \
45 REGISTER_OP(name) \
46 .Input("shape: Tshape") \
47 .Input("key: uint64") \
48 .Input("counter: uint64") \
49 .Input("alg: int32") \
50 .Output("output: dtype") \
51 .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
52 .Attr("Tshape: {int32, int64} = DT_INT32") \
53 .SetShapeFn(StatelessShapeV2)
54
55REGISTER_STATELESS_OP("StatelessRandomUniformV2");
56REGISTER_STATELESS_OP("StatelessRandomNormalV2");
57REGISTER_STATELESS_OP("StatelessTruncatedNormalV2");
58
59#undef REGISTER_STATELESS_OP
60
61REGISTER_OP("StatelessRandomUniformIntV2")
62 .Input("shape: Tshape")
63 .Input("key: uint64")
64 .Input("counter: uint64")
65 .Input("alg: int32")
66 .Input("minval: dtype")
67 .Input("maxval: dtype")
68 .Output("output: dtype")
69 .Attr("dtype: {int32, int64, uint32, uint64}")
70 .Attr("Tshape: {int32, int64} = DT_INT32")
71 .SetShapeFn([](InferenceContext* c) {
72 ShapeHandle unused;
73 Status s = c->WithRank(c->input(4), 0, &unused);
74 if (!s.ok()) {
75 return errors::InvalidArgument(
76 "minval must be a scalar; got a tensor of shape ",
77 c->DebugString(c->input(4)));
78 }
79 s = c->WithRank(c->input(5), 0, &unused);
80 if (!s.ok()) {
81 return errors::InvalidArgument(
82 "maxval must be a scalar; got a tensor of shape ",
83 c->DebugString(c->input(5)));
84 }
85 return StatelessShapeV2(c);
86 });
87
88REGISTER_OP("StatelessRandomUniformFullIntV2")
89 .Input("shape: Tshape")
90 .Input("key: uint64")
91 .Input("counter: uint64")
92 .Input("alg: int32")
93 .Output("output: dtype")
94 .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
95 .Attr("Tshape: {int32, int64} = DT_INT32")
96 .SetShapeFn(StatelessShapeV2);
97
98REGISTER_OP("StatelessShuffle")
99 .Input("value: T")
100 .Input("key: uint64")
101 .Input("counter: uint64")
102 .Input("alg: int32")
103 .Output("output: T")
104 .Attr("T: type")
105 .SetShapeFn(shape_inference::UnchangedShape);
106
107REGISTER_OP("StatelessRandomGetKeyCounterAlg")
108 .Input("seed: Tseed")
109 .Output("key: uint64")
110 .Output("counter: uint64")
111 .Output("alg: int32")
112 .Attr("Tseed: {int32, int64} = DT_INT64")
113 .SetShapeFn([](InferenceContext* c) {
114 // Check seed shape
115 ShapeHandle seed;
116 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
117 DimensionHandle unused;
118 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
119
120 // Set output shapes
121 c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
122 c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
123 c->set_output(2, c->MakeShape({}));
124 return OkStatus();
125 });
126
127REGISTER_OP("StatelessRandomGetKeyCounter")
128 .Input("seed: Tseed")
129 .Output("key: uint64")
130 .Output("counter: uint64")
131 .Attr("Tseed: {int32, int64} = DT_INT64")
132 .SetShapeFn([](InferenceContext* c) {
133 // Check seed shape
134 ShapeHandle seed;
135 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
136 DimensionHandle unused;
137 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
138
139 // Set output shapes
140 c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
141 c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
142 return OkStatus();
143 });
144
145REGISTER_OP("StatelessRandomGetAlg")
146 .Output("alg: int32")
147 .SetIsStateful() // because outputs depend on device
148 .SetShapeFn([](InferenceContext* c) {
149 c->set_output(0, c->MakeShape({}));
150 return OkStatus();
151 });
152
153} // namespace tensorflow
154