1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/op.h"
17#include "tensorflow/core/framework/shape_inference.h"
18
19namespace tensorflow {
20
21using shape_inference::DimensionHandle;
22using shape_inference::InferenceContext;
23using shape_inference::ShapeHandle;
24
25REGISTER_OP("GRUBlockCell")
26 .Attr("T: {float}")
27 .Input("x: T")
28 .Input("h_prev: T")
29 .Input("w_ru: T")
30 .Input("w_c: T")
31 .Input("b_ru: T")
32 .Input("b_c: T")
33 .Output("r: T")
34 .Output("u: T")
35 .Output("c: T")
36 .Output("h: T")
37 .SetShapeFn([](InferenceContext* c) {
38 ShapeHandle x, h_prev;
39 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
40 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));
41
42 DimensionHandle batch_size = c->Dim(x, 0);
43 DimensionHandle cell_size = c->Dim(h_prev, 1);
44 ShapeHandle output = c->Matrix(batch_size, cell_size);
45 for (int i = 0; i < 4; ++i) {
46 c->set_output(i, output);
47 }
48 return OkStatus();
49 });
50
51REGISTER_OP("GRUBlockCellGrad")
52 .Attr("T: {float}")
53 .Input("x: T")
54 .Input("h_prev: T")
55 .Input("w_ru: T")
56 .Input("w_c: T")
57 .Input("b_ru: T")
58 .Input("b_c: T")
59 .Input("r: T")
60 .Input("u: T")
61 .Input("c: T")
62 .Input("d_h: T")
63 .Output("d_x: T")
64 .Output("d_h_prev: T")
65 .Output("d_c_bar: T")
66 .Output("d_r_bar_u_bar: T")
67 .SetShapeFn([](InferenceContext* c) {
68 ShapeHandle x, h_prev, w_ru;
69 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
70 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));
71 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &w_ru));
72
73 DimensionHandle batch_size = c->Dim(x, 0);
74 DimensionHandle cell_size = c->Dim(h_prev, 1);
75 DimensionHandle twice_cell_size = c->Dim(w_ru, 1);
76 ShapeHandle batch_cell_shape = c->Matrix(batch_size, cell_size);
77
78 c->set_output(0, x);
79 c->set_output(1, batch_cell_shape);
80 c->set_output(2, batch_cell_shape);
81 c->set_output(3, c->Matrix(batch_size, twice_cell_size));
82 return OkStatus();
83 });
84
85REGISTER_OP("LSTMBlockCell")
86 .Input("x: T")
87 .Input("cs_prev: T")
88 .Input("h_prev: T")
89 .Input("w: T")
90 .Input("wci: T")
91 .Input("wcf: T")
92 .Input("wco: T")
93 .Input("b: T")
94 .Output("i: T")
95 .Output("cs: T")
96 .Output("f: T")
97 .Output("o: T")
98 .Output("ci: T")
99 .Output("co: T")
100 .Output("h: T")
101 .Attr("forget_bias: float = 1.0")
102 .Attr("cell_clip: float = 3.0")
103 .Attr("use_peephole: bool = false")
104 .Attr("T: {half, float}")
105 .SetShapeFn([](InferenceContext* c) {
106 ShapeHandle x, cs_prev;
107 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
108 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev));
109
110 DimensionHandle batch_size = c->Dim(x, 0);
111 DimensionHandle cell_size = c->Dim(cs_prev, 1);
112 ShapeHandle output = c->Matrix(batch_size, cell_size);
113 for (int i = 0; i < 7; ++i) {
114 c->set_output(i, output);
115 }
116 return OkStatus();
117 });
118
119REGISTER_OP("LSTMBlockCellGrad")
120 .Input("x: T")
121 .Input("cs_prev: T")
122 .Input("h_prev: T")
123 .Input("w: T")
124 .Input("wci: T")
125 .Input("wcf: T")
126 .Input("wco: T")
127 .Input("b: T")
128 .Input("i: T")
129 .Input("cs: T")
130 .Input("f: T")
131 .Input("o: T")
132 .Input("ci: T")
133 .Input("co: T")
134 .Input("cs_grad: T")
135 .Input("h_grad: T")
136 .Output("cs_prev_grad: T")
137 .Output("dicfo: T")
138 .Output("wci_grad: T")
139 .Output("wcf_grad: T")
140 .Output("wco_grad: T")
141 .Attr("use_peephole: bool")
142 .Attr("T: {half, float}")
143 .SetShapeFn([](InferenceContext* c) {
144 ShapeHandle x, cs_prev;
145 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
146 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev));
147
148 DimensionHandle batch_size = c->Dim(x, 0);
149 DimensionHandle cell_size = c->Dim(cs_prev, 1);
150 DimensionHandle cell_size_times_4;
151 TF_RETURN_IF_ERROR(c->Multiply(cell_size, 4, &cell_size_times_4));
152 ShapeHandle cell_size_vec = c->Vector(cell_size);
153
154 c->set_output(0, c->Matrix(batch_size, cell_size));
155 c->set_output(1, c->Matrix(batch_size, cell_size_times_4));
156 c->set_output(2, cell_size_vec);
157 c->set_output(3, cell_size_vec);
158 c->set_output(4, cell_size_vec);
159 return OkStatus();
160 });
161
162REGISTER_OP("BlockLSTM")
163 .Input("seq_len_max: int64")
164 .Input("x: T")
165 .Input("cs_prev: T")
166 .Input("h_prev: T")
167 .Input("w: T")
168 .Input("wci: T")
169 .Input("wcf: T")
170 .Input("wco: T")
171 .Input("b: T")
172 .Output("i: T")
173 .Output("cs: T")
174 .Output("f: T")
175 .Output("o: T")
176 .Output("ci: T")
177 .Output("co: T")
178 .Output("h: T")
179 .Attr("forget_bias: float = 1.0")
180 .Attr("cell_clip: float = 3.0")
181 .Attr("use_peephole: bool = false")
182 .Attr("T: {half, float}")
183 .SetShapeFn([](InferenceContext* c) {
184 ShapeHandle x, b;
185 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
186 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));
187
188 DimensionHandle timelen = c->Dim(x, 0);
189 DimensionHandle batch_size = c->Dim(x, 1);
190 DimensionHandle cell_size;
191 TF_RETURN_IF_ERROR(
192 c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));
193
194 DCHECK_EQ(7, c->num_outputs());
195 ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size});
196 for (int i = 0; i < 7; ++i) {
197 c->set_output(i, output);
198 }
199 return OkStatus();
200 });
201
202REGISTER_OP("BlockLSTMV2")
203 .Input("seq_len_max: int64")
204 .Input("x: T")
205 .Input("cs_prev: T")
206 .Input("h_prev: T")
207 .Input("w: T")
208 .Input("wci: T")
209 .Input("wcf: T")
210 .Input("wco: T")
211 .Input("b: T")
212 .Output("i: T")
213 .Output("cs: T")
214 .Output("f: T")
215 .Output("o: T")
216 .Output("ci: T")
217 .Output("co: T")
218 .Output("h: T")
219 .Attr("cell_clip: float = 0.0")
220 .Attr("use_peephole: bool = false")
221 .Attr("T: {half, float}")
222 .SetShapeFn([](InferenceContext* c) {
223 ShapeHandle x, b;
224 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
225 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));
226
227 DimensionHandle timelen = c->Dim(x, 0);
228 DimensionHandle batch_size = c->Dim(x, 1);
229 DimensionHandle cell_size;
230 TF_RETURN_IF_ERROR(
231 c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));
232
233 DCHECK_EQ(7, c->num_outputs());
234 ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size});
235 for (int i = 0; i < 7; ++i) {
236 c->set_output(i, output);
237 }
238 return OkStatus();
239 });
240
241REGISTER_OP("BlockLSTMGrad")
242 .Input("seq_len_max: int64")
243 .Input("x: T")
244 .Input("cs_prev: T")
245 .Input("h_prev: T")
246 .Input("w: T")
247 .Input("wci: T")
248 .Input("wcf: T")
249 .Input("wco: T")
250 .Input("b: T")
251 .Input("i: T")
252 .Input("cs: T")
253 .Input("f: T")
254 .Input("o: T")
255 .Input("ci: T")
256 .Input("co: T")
257 .Input("h: T")
258 .Input("cs_grad: T")
259 .Input("h_grad: T")
260 .Output("x_grad: T")
261 .Output("cs_prev_grad: T")
262 .Output("h_prev_grad: T")
263 .Output("w_grad: T")
264 .Output("wci_grad: T")
265 .Output("wcf_grad: T")
266 .Output("wco_grad: T")
267 .Output("b_grad: T")
268 .Attr("use_peephole: bool")
269 .Attr("T: {half, float}")
270 .SetShapeFn([](InferenceContext* c) {
271 ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b;
272 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
273 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev));
274 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev));
275 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w));
276 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci));
277 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco));
278 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf));
279 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b));
280
281 c->set_output(0, x);
282 c->set_output(1, cs_prev);
283 c->set_output(2, h_prev);
284 c->set_output(3, w);
285 c->set_output(4, wci);
286 c->set_output(5, wco);
287 c->set_output(6, wcf);
288 c->set_output(7, b);
289
290 return OkStatus();
291 });
292
293REGISTER_OP("BlockLSTMGradV2")
294 .Input("seq_len_max: int64")
295 .Input("x: T")
296 .Input("cs_prev: T")
297 .Input("h_prev: T")
298 .Input("w: T")
299 .Input("wci: T")
300 .Input("wcf: T")
301 .Input("wco: T")
302 .Input("b: T")
303 .Input("i: T")
304 .Input("cs: T")
305 .Input("f: T")
306 .Input("o: T")
307 .Input("ci: T")
308 .Input("co: T")
309 .Input("h: T")
310 .Input("cs_grad: T")
311 .Input("h_grad: T")
312 .Output("x_grad: T")
313 .Output("cs_prev_grad: T")
314 .Output("h_prev_grad: T")
315 .Output("w_grad: T")
316 .Output("wci_grad: T")
317 .Output("wcf_grad: T")
318 .Output("wco_grad: T")
319 .Output("b_grad: T")
320 .Attr("use_peephole: bool")
321 .Attr("T: {half, float}")
322 .SetShapeFn([](InferenceContext* c) {
323 ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b;
324 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
325 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev));
326 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev));
327 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w));
328 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci));
329 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco));
330 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf));
331 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b));
332
333 c->set_output(0, x);
334 c->set_output(1, cs_prev);
335 c->set_output(2, h_prev);
336 c->set_output(3, w);
337 c->set_output(4, wci);
338 c->set_output(5, wco);
339 c->set_output(6, wcf);
340 c->set_output(7, b);
341
342 return OkStatus();
343 });
344
345} // end namespace tensorflow
346