1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/rng_alg.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | using shape_inference::DimensionHandle; |
23 | using shape_inference::InferenceContext; |
24 | using shape_inference::ShapeHandle; |
25 | |
26 | static Status StatelessShapeV2(InferenceContext* c) { |
27 | // Check key and counter shapes |
28 | ShapeHandle key; |
29 | ShapeHandle counter; |
30 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &key)); |
31 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &counter)); |
32 | shape_inference::ShapeHandle unused_shape; |
33 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); |
34 | DimensionHandle unused; |
35 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), RNG_KEY_SIZE, &unused)); |
36 | |
37 | // Set output shape |
38 | ShapeHandle out; |
39 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
40 | c->set_output(0, out); |
41 | return OkStatus(); |
42 | } |
43 | |
// Registers one stateless sampling op called `name`.
//
// Every op registered through this macro has the same signature:
//   inputs : shape (Tshape), key (uint64), counter (uint64), alg (int32)
//   output : output (dtype)
//   attrs  : dtype  in {half, bfloat16, float, double}, default DT_FLOAT
//            Tshape in {int32, int64}, default DT_INT32
// Shape inference is handled by StatelessShapeV2.
//
// The macro is undefined again right after its last use so that it does not
// leak into the rest of the translation unit.
#define REGISTER_STATELESS_OP(name)                           \
  REGISTER_OP(name)                                           \
      .Input("shape: Tshape")                                 \
      .Input("key: uint64")                                   \
      .Input("counter: uint64")                               \
      .Input("alg: int32")                                    \
      .Output("output: dtype")                                \
      .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
      .Attr("Tshape: {int32, int64} = DT_INT32")              \
      .SetShapeFn(StatelessShapeV2)

// The three ops that share the signature above.
REGISTER_STATELESS_OP("StatelessRandomUniformV2");
REGISTER_STATELESS_OP("StatelessRandomNormalV2");
REGISTER_STATELESS_OP("StatelessTruncatedNormalV2");

#undef REGISTER_STATELESS_OP
60 | |
61 | REGISTER_OP("StatelessRandomUniformIntV2" ) |
62 | .Input("shape: Tshape" ) |
63 | .Input("key: uint64" ) |
64 | .Input("counter: uint64" ) |
65 | .Input("alg: int32" ) |
66 | .Input("minval: dtype" ) |
67 | .Input("maxval: dtype" ) |
68 | .Output("output: dtype" ) |
69 | .Attr("dtype: {int32, int64, uint32, uint64}" ) |
70 | .Attr("Tshape: {int32, int64} = DT_INT32" ) |
71 | .SetShapeFn([](InferenceContext* c) { |
72 | ShapeHandle unused; |
73 | Status s = c->WithRank(c->input(4), 0, &unused); |
74 | if (!s.ok()) { |
75 | return errors::InvalidArgument( |
76 | "minval must be a scalar; got a tensor of shape " , |
77 | c->DebugString(c->input(4))); |
78 | } |
79 | s = c->WithRank(c->input(5), 0, &unused); |
80 | if (!s.ok()) { |
81 | return errors::InvalidArgument( |
82 | "maxval must be a scalar; got a tensor of shape " , |
83 | c->DebugString(c->input(5))); |
84 | } |
85 | return StatelessShapeV2(c); |
86 | }); |
87 | |
// Integer sampling op without min/max inputs.
//
// Inputs : shape (Tshape), key (uint64), counter (uint64), alg (int32)
// Output : output (dtype)
// Attrs  : dtype  in {int32, int64, uint32, uint64}, default DT_UINT64
//          Tshape in {int32, int64}, default DT_INT32
// Shape inference is handled by StatelessShapeV2.
REGISTER_OP("StatelessRandomUniformFullIntV2")
    .Input("shape: Tshape")
    .Input("key: uint64")
    .Input("counter: uint64")
    .Input("alg: int32")
    .Output("output: dtype")
    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn(StatelessShapeV2);
97 | |
// Shuffle op driven by an explicit key/counter/alg triple.
//
// Inputs : value (T), key (uint64), counter (uint64), alg (int32)
// Output : output (T), where T may be any type
//
// Shape inference is delegated to shape_inference::UnchangedShape rather than
// StatelessShapeV2, so none of StatelessShapeV2's key/counter/alg checks run
// for this op.
// NOTE(review): UnchangedShape presumably makes the output shape mirror the
// `value` input; confirm against common_shape_fns.h.
REGISTER_OP("StatelessShuffle")
    .Input("value: T")
    .Input("key: uint64")
    .Input("counter: uint64")
    .Input("alg: int32")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnchangedShape);
106 | |
107 | REGISTER_OP("StatelessRandomGetKeyCounterAlg" ) |
108 | .Input("seed: Tseed" ) |
109 | .Output("key: uint64" ) |
110 | .Output("counter: uint64" ) |
111 | .Output("alg: int32" ) |
112 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
113 | .SetShapeFn([](InferenceContext* c) { |
114 | // Check seed shape |
115 | ShapeHandle seed; |
116 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed)); |
117 | DimensionHandle unused; |
118 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); |
119 | |
120 | // Set output shapes |
121 | c->set_output(0, c->MakeShape({RNG_KEY_SIZE})); |
122 | c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE})); |
123 | c->set_output(2, c->MakeShape({})); |
124 | return OkStatus(); |
125 | }); |
126 | |
127 | REGISTER_OP("StatelessRandomGetKeyCounter" ) |
128 | .Input("seed: Tseed" ) |
129 | .Output("key: uint64" ) |
130 | .Output("counter: uint64" ) |
131 | .Attr("Tseed: {int32, int64} = DT_INT64" ) |
132 | .SetShapeFn([](InferenceContext* c) { |
133 | // Check seed shape |
134 | ShapeHandle seed; |
135 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed)); |
136 | DimensionHandle unused; |
137 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); |
138 | |
139 | // Set output shapes |
140 | c->set_output(0, c->MakeShape({RNG_KEY_SIZE})); |
141 | c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE})); |
142 | return OkStatus(); |
143 | }); |
144 | |
145 | REGISTER_OP("StatelessRandomGetAlg" ) |
146 | .Output("alg: int32" ) |
147 | .SetIsStateful() // because outputs depend on device |
148 | .SetShapeFn([](InferenceContext* c) { |
149 | c->set_output(0, c->MakeShape({})); |
150 | return OkStatus(); |
151 | }); |
152 | |
153 | } // namespace tensorflow |
154 | |