1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19#include "tensorflow/core/lib/strings/strcat.h"
20
21namespace tensorflow {
namespace {

// Attr-spec strings shared by all Cudnn RNN ops in this file.  Centralizing
// them keeps the accepted values and defaults consistent across the forward,
// backward, and params-conversion ops.

// RNN cell type supported by cuDNN: plain RNN (relu or tanh), LSTM, or GRU.
constexpr auto kRNNModeAttrs =
    "rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'";

// How the input feeds the first layer: through a linear projection,
// passed through unchanged ('skip_input'), or selected automatically.
constexpr auto kRNNInputModeAttrs =
    "input_mode: {'linear_input', 'skip_input', 'auto_select'} = "
    "'linear_input'";

// Whether the RNN runs forward only or in both directions.
constexpr auto kRNNDirectionAttrs =
    "direction: {'unidirectional', 'bidirectional'} = 'unidirectional'";

}  // namespace
35
36using shape_inference::DimensionHandle;
37using shape_inference::InferenceContext;
38using shape_inference::ShapeHandle;
39
40REGISTER_OP("CudnnRNNParamsSize")
41 .Input("num_layers: int32")
42 .Input("num_units: int32")
43 .Input("input_size: int32")
44 .Attr("T: {float16, float32, float64}")
45 .Attr("S: {int32, int64}")
46 .Attr(kRNNModeAttrs)
47 .Attr(kRNNInputModeAttrs)
48 .Attr(kRNNDirectionAttrs)
49 .Attr("dropout: float = 0.0")
50 .Attr("seed: int = 0")
51 .Attr("seed2: int = 0")
52 .Attr("num_proj: int = 0")
53 .Output("params_size: S")
54 .SetShapeFn([](InferenceContext* c) {
55 ShapeHandle unused;
56 // num_layers, num_units, and input_size should be scalars.
57 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
58 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
59 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
60
61 c->set_output(0, c->Vector(1));
62 return OkStatus();
63 });
64
65REGISTER_OP("CudnnRNN")
66 .Input("input: T")
67 .Input("input_h: T")
68 .Input("input_c: T")
69 .Input("params: T")
70 .SetIsStateful()
71 .Output("output: T")
72 .Output("output_h: T")
73 .Output("output_c: T")
74 .Output("reserve_space: T")
75 .Attr("T: {float16, float32, float64}")
76 .Attr(kRNNModeAttrs)
77 .Attr(kRNNInputModeAttrs)
78 .Attr(kRNNDirectionAttrs)
79 .Attr("dropout: float = 0.0")
80 .Attr("seed: int = 0")
81 .Attr("seed2: int = 0")
82 .Attr("is_training: bool = true")
83 .SetShapeFn([](InferenceContext* c) {
84 ShapeHandle unused;
85 auto input_shape = c->input(0);
86 auto input_h_shape = c->input(1);
87 TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
88 TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
89 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
90
91 auto seq_length = c->Dim(input_shape, 0);
92 auto batch_size = c->Dim(input_shape, 1);
93 auto num_units = c->Dim(input_h_shape, 2);
94
95 string direction;
96 TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
97 string rnn_mode;
98 TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
99 int dir_count = (direction == "bidirectional") ? 2 : 1;
100 DimensionHandle output_size;
101 TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
102 auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
103 auto output_h_shape = input_h_shape;
104 auto output_c_shape TF_ATTRIBUTE_UNUSED =
105 (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
106 c->set_output(0, output_shape);
107 c->set_output(1, output_h_shape);
108 c->set_output(2, output_c_shape);
109 c->set_output(3, c->UnknownShape());
110 return OkStatus();
111 });
112
113REGISTER_OP("CudnnRNNV2")
114 .Input("input: T")
115 .Input("input_h: T")
116 .Input("input_c: T")
117 .Input("params: T")
118 .SetIsStateful()
119 .Output("output: T")
120 .Output("output_h: T")
121 .Output("output_c: T")
122 .Output("reserve_space: T")
123 .Output("host_reserved: int8")
124 .Attr("T: {float16, float32, float64}")
125 .Attr(kRNNModeAttrs)
126 .Attr(kRNNInputModeAttrs)
127 .Attr(kRNNDirectionAttrs)
128 .Attr("dropout: float = 0.0")
129 .Attr("seed: int = 0")
130 .Attr("seed2: int = 0")
131 .Attr("is_training: bool = true")
132 .SetShapeFn([](InferenceContext* c) {
133 ShapeHandle unused;
134 auto input_shape = c->input(0);
135 auto input_h_shape = c->input(1);
136 TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
137 TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
138 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
139
140 auto seq_length = c->Dim(input_shape, 0);
141 auto batch_size = c->Dim(input_shape, 1);
142 auto num_units = c->Dim(input_h_shape, 2);
143 string direction;
144 TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
145 string rnn_mode;
146 TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
147 int dir_count = (direction == "bidirectional") ? 2 : 1;
148 DimensionHandle output_size;
149 TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
150 auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
151 auto output_h_shape = input_h_shape;
152 auto output_c_shape TF_ATTRIBUTE_UNUSED =
153 (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
154 c->set_output(0, output_shape);
155 c->set_output(1, output_h_shape);
156 c->set_output(2, output_c_shape);
157 c->set_output(3, c->UnknownShape());
158 c->set_output(4, c->UnknownShape());
159 return OkStatus();
160 });
161
162REGISTER_OP("CudnnRNNV3")
163 .Input("input: T")
164 .Input("input_h: T")
165 .Input("input_c: T")
166 .Input("params: T")
167 .Input("sequence_lengths: int32")
168 .SetIsStateful()
169 .Output("output: T")
170 .Output("output_h: T")
171 .Output("output_c: T")
172 .Output("reserve_space: T")
173 .Output("host_reserved: int8")
174 .Attr("T: {float16, float32, float64}")
175 .Attr(kRNNModeAttrs)
176 .Attr(kRNNInputModeAttrs)
177 .Attr(kRNNDirectionAttrs)
178 .Attr("dropout: float = 0.0")
179 .Attr("seed: int = 0")
180 .Attr("seed2: int = 0")
181 .Attr("num_proj: int = 0")
182 .Attr("is_training: bool = true")
183 .Attr("time_major: bool = true")
184 .SetShapeFn([](InferenceContext* c) {
185 ShapeHandle unused;
186 auto input_shape = c->input(0);
187 auto input_h_shape = c->input(1);
188 auto input_c_shape = c->input(2);
189 TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
190 TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
191 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
192 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));
193
194 auto max_seq_length = c->Dim(input_shape, 0);
195 auto batch_size = c->Dim(input_shape, 1);
196 auto num_units = c->Dim(input_h_shape, 2);
197
198 string direction;
199 TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
200 string rnn_mode;
201 TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
202 if (rnn_mode == "lstm") {
203 TF_RETURN_IF_ERROR(c->WithRank(input_c_shape, 3, &unused));
204 }
205 int dir_count = (direction == "bidirectional") ? 2 : 1;
206 DimensionHandle output_size;
207 TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
208 auto output_shape =
209 c->MakeShape({max_seq_length, batch_size, output_size});
210 auto output_h_shape = input_h_shape;
211 auto output_c_shape TF_ATTRIBUTE_UNUSED =
212 (rnn_mode == "lstm") ? input_c_shape : c->MakeShape({});
213 c->set_output(0, output_shape);
214 c->set_output(1, output_h_shape);
215 c->set_output(2, output_c_shape);
216 c->set_output(3, c->UnknownShape());
217 c->set_output(4, c->UnknownShape());
218 return OkStatus();
219 });
220
221REGISTER_OP("CudnnRNNBackprop")
222 .Input("input: T")
223 .Input("input_h: T")
224 .Input("input_c: T")
225 .Input("params: T")
226 .Input("output: T")
227 .Input("output_h: T")
228 .Input("output_c: T")
229 .Input("output_backprop: T")
230 .Input("output_h_backprop: T")
231 .Input("output_c_backprop: T")
232 .Input("reserve_space: T")
233 .SetIsStateful()
234 .Output("input_backprop: T")
235 .Output("input_h_backprop: T")
236 .Output("input_c_backprop: T")
237 .Output("params_backprop: T")
238 .Attr("T: {float16, float32, float64}")
239 .Attr(kRNNModeAttrs)
240 .Attr(kRNNInputModeAttrs)
241 .Attr(kRNNDirectionAttrs)
242 .Attr("dropout: float = 0.0")
243 .Attr("seed: int = 0")
244 .Attr("seed2: int = 0")
245 .SetShapeFn([](InferenceContext* c) {
246 auto input_shape = c->input(0);
247 auto input_h_shape = c->input(1);
248 auto input_c_shape = c->input(2);
249 auto params_shape = c->input(3);
250 c->set_output(0, input_shape);
251 c->set_output(1, input_h_shape);
252 c->set_output(2, input_c_shape);
253 c->set_output(3, params_shape);
254 return OkStatus();
255 });
256
257REGISTER_OP("CudnnRNNBackpropV2")
258 .Input("input: T")
259 .Input("input_h: T")
260 .Input("input_c: T")
261 .Input("params: T")
262 .Input("output: T")
263 .Input("output_h: T")
264 .Input("output_c: T")
265 .Input("output_backprop: T")
266 .Input("output_h_backprop: T")
267 .Input("output_c_backprop: T")
268 .Input("reserve_space: T")
269 .Input("host_reserved: int8")
270 .SetIsStateful()
271 .Output("input_backprop: T")
272 .Output("input_h_backprop: T")
273 .Output("input_c_backprop: T")
274 .Output("params_backprop: T")
275 .Attr("T: {float16, float32, float64}")
276 .Attr(kRNNModeAttrs)
277 .Attr(kRNNInputModeAttrs)
278 .Attr(kRNNDirectionAttrs)
279 .Attr("dropout: float = 0.0")
280 .Attr("seed: int = 0")
281 .Attr("seed2: int = 0")
282 .SetShapeFn([](InferenceContext* c) {
283 auto input_shape = c->input(0);
284 auto input_h_shape = c->input(1);
285 auto input_c_shape = c->input(2);
286 auto params_shape = c->input(3);
287 c->set_output(0, input_shape);
288 c->set_output(1, input_h_shape);
289 c->set_output(2, input_c_shape);
290 c->set_output(3, params_shape);
291 return OkStatus();
292 });
293
294REGISTER_OP("CudnnRNNBackpropV3")
295 .Input("input: T")
296 .Input("input_h: T")
297 .Input("input_c: T")
298 .Input("params: T")
299 .Input("sequence_lengths: int32")
300 .Input("output: T")
301 .Input("output_h: T")
302 .Input("output_c: T")
303 .Input("output_backprop: T")
304 .Input("output_h_backprop: T")
305 .Input("output_c_backprop: T")
306 .Input("reserve_space: T")
307 .Input("host_reserved: int8")
308 .SetIsStateful()
309 .Output("input_backprop: T")
310 .Output("input_h_backprop: T")
311 .Output("input_c_backprop: T")
312 .Output("params_backprop: T")
313 .Attr("T: {float16, float32, float64}")
314 .Attr(kRNNModeAttrs)
315 .Attr(kRNNInputModeAttrs)
316 .Attr(kRNNDirectionAttrs)
317 .Attr("dropout: float = 0.0")
318 .Attr("seed: int = 0")
319 .Attr("seed2: int = 0")
320 .Attr("num_proj: int = 0")
321 .Attr("time_major: bool = true")
322 .SetShapeFn([](InferenceContext* c) {
323 auto input_shape = c->input(0);
324 auto input_h_shape = c->input(1);
325 auto input_c_shape = c->input(2);
326 auto params_shape = c->input(3);
327 c->set_output(0, input_shape);
328 c->set_output(1, input_h_shape);
329 c->set_output(2, input_c_shape);
330 c->set_output(3, params_shape);
331 return OkStatus();
332 });
333
334REGISTER_OP("CudnnRNNParamsToCanonical")
335 .Input("num_layers: int32")
336 .Input("num_units: int32")
337 .Input("input_size: int32")
338 .Input("params: T")
339 .Output("weights: num_params * T")
340 .Output("biases: num_params * T")
341 .Attr("T: {float16, float32, float64}")
342 .Attr("num_params: int")
343 .Attr(kRNNModeAttrs)
344 .Attr(kRNNInputModeAttrs)
345 .Attr(kRNNDirectionAttrs)
346 .Attr("dropout: float = 0.0")
347 .Attr("seed: int = 0")
348 .Attr("seed2: int = 0")
349 .SetShapeFn([](InferenceContext* c) {
350 ShapeHandle unused;
351 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
352 int num_params;
353 TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params));
354 // Set shape for weight matrices
355 for (int i = 0; i < num_params; i++) {
356 c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
357 InferenceContext::kUnknownDim));
358 }
359 // Set shape for bias vectors
360 for (int i = 0; i < num_params; i++) {
361 c->set_output(num_params + i, c->Vector(InferenceContext::kUnknownDim));
362 }
363 return OkStatus();
364 });
365
366REGISTER_OP("CudnnRNNParamsToCanonicalV2")
367 .Input("num_layers: int32")
368 .Input("num_units: int32")
369 .Input("input_size: int32")
370 .Input("params: T")
371 .Output("weights: num_params_weights * T")
372 .Output("biases: num_params_biases * T")
373 .Attr("T: {float16, float32, float64}")
374 .Attr("num_params_weights: int")
375 .Attr("num_params_biases: int")
376 .Attr(kRNNModeAttrs)
377 .Attr(kRNNInputModeAttrs)
378 .Attr(kRNNDirectionAttrs)
379 .Attr("dropout: float = 0.0")
380 .Attr("seed: int = 0")
381 .Attr("seed2: int = 0")
382 .Attr("num_proj: int = 0")
383 .SetShapeFn([](InferenceContext* c) {
384 ShapeHandle unused;
385 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
386 int num_params_weights;
387 int num_params_biases;
388 TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights));
389 TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases));
390 // Set shape for weight matrices
391 for (int i = 0; i < num_params_weights; i++) {
392 c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
393 InferenceContext::kUnknownDim));
394 }
395 // Set shape for bias vectors
396 for (int i = 0; i < num_params_biases; i++) {
397 c->set_output(num_params_weights + i,
398 c->Vector(InferenceContext::kUnknownDim));
399 }
400 return OkStatus();
401 });
402
403REGISTER_OP("CudnnRNNCanonicalToParams")
404 .Input("num_layers: int32")
405 .Input("num_units: int32")
406 .Input("input_size: int32")
407 .Input("weights: num_params * T")
408 .Input("biases: num_params * T")
409 .Output("params: T")
410 .Attr("T: {float16, float32, float64}")
411 .Attr("num_params: int")
412 .Attr(kRNNModeAttrs)
413 .Attr(kRNNInputModeAttrs)
414 .Attr(kRNNDirectionAttrs)
415 .Attr("dropout: float = 0.0")
416 .Attr("seed: int = 0")
417 .Attr("seed2: int = 0")
418 .SetShapeFn([](InferenceContext* c) {
419 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
420 return OkStatus();
421 });
422
423REGISTER_OP("CudnnRNNCanonicalToParamsV2")
424 .Input("num_layers: int32")
425 .Input("num_units: int32")
426 .Input("input_size: int32")
427 .Input("weights: num_params_weights * T")
428 .Input("biases: num_params_biases * T")
429 .Output("params: T")
430 .Attr("T: {float16, float32, float64}")
431 .Attr("num_params_weights: int")
432 .Attr("num_params_biases: int")
433 .Attr(kRNNModeAttrs)
434 .Attr(kRNNInputModeAttrs)
435 .Attr(kRNNDirectionAttrs)
436 .Attr("dropout: float = 0.0")
437 .Attr("seed: int = 0")
438 .Attr("seed2: int = 0")
439 .Attr("num_proj: int = 0")
440 .SetShapeFn([](InferenceContext* c) {
441 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
442 return OkStatus();
443 });
444
445} // namespace tensorflow
446