1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/shape_inference.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::DimensionHandle; |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | REGISTER_OP("GRUBlockCell" ) |
26 | .Attr("T: {float}" ) |
27 | .Input("x: T" ) |
28 | .Input("h_prev: T" ) |
29 | .Input("w_ru: T" ) |
30 | .Input("w_c: T" ) |
31 | .Input("b_ru: T" ) |
32 | .Input("b_c: T" ) |
33 | .Output("r: T" ) |
34 | .Output("u: T" ) |
35 | .Output("c: T" ) |
36 | .Output("h: T" ) |
37 | .SetShapeFn([](InferenceContext* c) { |
38 | ShapeHandle x, h_prev; |
39 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); |
40 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev)); |
41 | |
42 | DimensionHandle batch_size = c->Dim(x, 0); |
43 | DimensionHandle cell_size = c->Dim(h_prev, 1); |
44 | ShapeHandle output = c->Matrix(batch_size, cell_size); |
45 | for (int i = 0; i < 4; ++i) { |
46 | c->set_output(i, output); |
47 | } |
48 | return OkStatus(); |
49 | }); |
50 | |
51 | REGISTER_OP("GRUBlockCellGrad" ) |
52 | .Attr("T: {float}" ) |
53 | .Input("x: T" ) |
54 | .Input("h_prev: T" ) |
55 | .Input("w_ru: T" ) |
56 | .Input("w_c: T" ) |
57 | .Input("b_ru: T" ) |
58 | .Input("b_c: T" ) |
59 | .Input("r: T" ) |
60 | .Input("u: T" ) |
61 | .Input("c: T" ) |
62 | .Input("d_h: T" ) |
63 | .Output("d_x: T" ) |
64 | .Output("d_h_prev: T" ) |
65 | .Output("d_c_bar: T" ) |
66 | .Output("d_r_bar_u_bar: T" ) |
67 | .SetShapeFn([](InferenceContext* c) { |
68 | ShapeHandle x, h_prev, w_ru; |
69 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); |
70 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev)); |
71 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &w_ru)); |
72 | |
73 | DimensionHandle batch_size = c->Dim(x, 0); |
74 | DimensionHandle cell_size = c->Dim(h_prev, 1); |
75 | DimensionHandle twice_cell_size = c->Dim(w_ru, 1); |
76 | ShapeHandle batch_cell_shape = c->Matrix(batch_size, cell_size); |
77 | |
78 | c->set_output(0, x); |
79 | c->set_output(1, batch_cell_shape); |
80 | c->set_output(2, batch_cell_shape); |
81 | c->set_output(3, c->Matrix(batch_size, twice_cell_size)); |
82 | return OkStatus(); |
83 | }); |
84 | |
85 | REGISTER_OP("LSTMBlockCell" ) |
86 | .Input("x: T" ) |
87 | .Input("cs_prev: T" ) |
88 | .Input("h_prev: T" ) |
89 | .Input("w: T" ) |
90 | .Input("wci: T" ) |
91 | .Input("wcf: T" ) |
92 | .Input("wco: T" ) |
93 | .Input("b: T" ) |
94 | .Output("i: T" ) |
95 | .Output("cs: T" ) |
96 | .Output("f: T" ) |
97 | .Output("o: T" ) |
98 | .Output("ci: T" ) |
99 | .Output("co: T" ) |
100 | .Output("h: T" ) |
101 | .Attr("forget_bias: float = 1.0" ) |
102 | .Attr("cell_clip: float = 3.0" ) |
103 | .Attr("use_peephole: bool = false" ) |
104 | .Attr("T: {half, float}" ) |
105 | .SetShapeFn([](InferenceContext* c) { |
106 | ShapeHandle x, cs_prev; |
107 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); |
108 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev)); |
109 | |
110 | DimensionHandle batch_size = c->Dim(x, 0); |
111 | DimensionHandle cell_size = c->Dim(cs_prev, 1); |
112 | ShapeHandle output = c->Matrix(batch_size, cell_size); |
113 | for (int i = 0; i < 7; ++i) { |
114 | c->set_output(i, output); |
115 | } |
116 | return OkStatus(); |
117 | }); |
118 | |
119 | REGISTER_OP("LSTMBlockCellGrad" ) |
120 | .Input("x: T" ) |
121 | .Input("cs_prev: T" ) |
122 | .Input("h_prev: T" ) |
123 | .Input("w: T" ) |
124 | .Input("wci: T" ) |
125 | .Input("wcf: T" ) |
126 | .Input("wco: T" ) |
127 | .Input("b: T" ) |
128 | .Input("i: T" ) |
129 | .Input("cs: T" ) |
130 | .Input("f: T" ) |
131 | .Input("o: T" ) |
132 | .Input("ci: T" ) |
133 | .Input("co: T" ) |
134 | .Input("cs_grad: T" ) |
135 | .Input("h_grad: T" ) |
136 | .Output("cs_prev_grad: T" ) |
137 | .Output("dicfo: T" ) |
138 | .Output("wci_grad: T" ) |
139 | .Output("wcf_grad: T" ) |
140 | .Output("wco_grad: T" ) |
141 | .Attr("use_peephole: bool" ) |
142 | .Attr("T: {half, float}" ) |
143 | .SetShapeFn([](InferenceContext* c) { |
144 | ShapeHandle x, cs_prev; |
145 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); |
146 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev)); |
147 | |
148 | DimensionHandle batch_size = c->Dim(x, 0); |
149 | DimensionHandle cell_size = c->Dim(cs_prev, 1); |
150 | DimensionHandle cell_size_times_4; |
151 | TF_RETURN_IF_ERROR(c->Multiply(cell_size, 4, &cell_size_times_4)); |
152 | ShapeHandle cell_size_vec = c->Vector(cell_size); |
153 | |
154 | c->set_output(0, c->Matrix(batch_size, cell_size)); |
155 | c->set_output(1, c->Matrix(batch_size, cell_size_times_4)); |
156 | c->set_output(2, cell_size_vec); |
157 | c->set_output(3, cell_size_vec); |
158 | c->set_output(4, cell_size_vec); |
159 | return OkStatus(); |
160 | }); |
161 | |
162 | REGISTER_OP("BlockLSTM" ) |
163 | .Input("seq_len_max: int64" ) |
164 | .Input("x: T" ) |
165 | .Input("cs_prev: T" ) |
166 | .Input("h_prev: T" ) |
167 | .Input("w: T" ) |
168 | .Input("wci: T" ) |
169 | .Input("wcf: T" ) |
170 | .Input("wco: T" ) |
171 | .Input("b: T" ) |
172 | .Output("i: T" ) |
173 | .Output("cs: T" ) |
174 | .Output("f: T" ) |
175 | .Output("o: T" ) |
176 | .Output("ci: T" ) |
177 | .Output("co: T" ) |
178 | .Output("h: T" ) |
179 | .Attr("forget_bias: float = 1.0" ) |
180 | .Attr("cell_clip: float = 3.0" ) |
181 | .Attr("use_peephole: bool = false" ) |
182 | .Attr("T: {half, float}" ) |
183 | .SetShapeFn([](InferenceContext* c) { |
184 | ShapeHandle x, b; |
185 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); |
186 | TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b)); |
187 | |
188 | DimensionHandle timelen = c->Dim(x, 0); |
189 | DimensionHandle batch_size = c->Dim(x, 1); |
190 | DimensionHandle cell_size; |
191 | TF_RETURN_IF_ERROR( |
192 | c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size)); |
193 | |
194 | DCHECK_EQ(7, c->num_outputs()); |
195 | ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size}); |
196 | for (int i = 0; i < 7; ++i) { |
197 | c->set_output(i, output); |
198 | } |
199 | return OkStatus(); |
200 | }); |
201 | |
202 | REGISTER_OP("BlockLSTMV2" ) |
203 | .Input("seq_len_max: int64" ) |
204 | .Input("x: T" ) |
205 | .Input("cs_prev: T" ) |
206 | .Input("h_prev: T" ) |
207 | .Input("w: T" ) |
208 | .Input("wci: T" ) |
209 | .Input("wcf: T" ) |
210 | .Input("wco: T" ) |
211 | .Input("b: T" ) |
212 | .Output("i: T" ) |
213 | .Output("cs: T" ) |
214 | .Output("f: T" ) |
215 | .Output("o: T" ) |
216 | .Output("ci: T" ) |
217 | .Output("co: T" ) |
218 | .Output("h: T" ) |
219 | .Attr("cell_clip: float = 0.0" ) |
220 | .Attr("use_peephole: bool = false" ) |
221 | .Attr("T: {half, float}" ) |
222 | .SetShapeFn([](InferenceContext* c) { |
223 | ShapeHandle x, b; |
224 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); |
225 | TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b)); |
226 | |
227 | DimensionHandle timelen = c->Dim(x, 0); |
228 | DimensionHandle batch_size = c->Dim(x, 1); |
229 | DimensionHandle cell_size; |
230 | TF_RETURN_IF_ERROR( |
231 | c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size)); |
232 | |
233 | DCHECK_EQ(7, c->num_outputs()); |
234 | ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size}); |
235 | for (int i = 0; i < 7; ++i) { |
236 | c->set_output(i, output); |
237 | } |
238 | return OkStatus(); |
239 | }); |
240 | |
241 | REGISTER_OP("BlockLSTMGrad" ) |
242 | .Input("seq_len_max: int64" ) |
243 | .Input("x: T" ) |
244 | .Input("cs_prev: T" ) |
245 | .Input("h_prev: T" ) |
246 | .Input("w: T" ) |
247 | .Input("wci: T" ) |
248 | .Input("wcf: T" ) |
249 | .Input("wco: T" ) |
250 | .Input("b: T" ) |
251 | .Input("i: T" ) |
252 | .Input("cs: T" ) |
253 | .Input("f: T" ) |
254 | .Input("o: T" ) |
255 | .Input("ci: T" ) |
256 | .Input("co: T" ) |
257 | .Input("h: T" ) |
258 | .Input("cs_grad: T" ) |
259 | .Input("h_grad: T" ) |
260 | .Output("x_grad: T" ) |
261 | .Output("cs_prev_grad: T" ) |
262 | .Output("h_prev_grad: T" ) |
263 | .Output("w_grad: T" ) |
264 | .Output("wci_grad: T" ) |
265 | .Output("wcf_grad: T" ) |
266 | .Output("wco_grad: T" ) |
267 | .Output("b_grad: T" ) |
268 | .Attr("use_peephole: bool" ) |
269 | .Attr("T: {half, float}" ) |
270 | .SetShapeFn([](InferenceContext* c) { |
271 | ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b; |
272 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); |
273 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev)); |
274 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev)); |
275 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w)); |
276 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci)); |
277 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco)); |
278 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf)); |
279 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b)); |
280 | |
281 | c->set_output(0, x); |
282 | c->set_output(1, cs_prev); |
283 | c->set_output(2, h_prev); |
284 | c->set_output(3, w); |
285 | c->set_output(4, wci); |
286 | c->set_output(5, wco); |
287 | c->set_output(6, wcf); |
288 | c->set_output(7, b); |
289 | |
290 | return OkStatus(); |
291 | }); |
292 | |
293 | REGISTER_OP("BlockLSTMGradV2" ) |
294 | .Input("seq_len_max: int64" ) |
295 | .Input("x: T" ) |
296 | .Input("cs_prev: T" ) |
297 | .Input("h_prev: T" ) |
298 | .Input("w: T" ) |
299 | .Input("wci: T" ) |
300 | .Input("wcf: T" ) |
301 | .Input("wco: T" ) |
302 | .Input("b: T" ) |
303 | .Input("i: T" ) |
304 | .Input("cs: T" ) |
305 | .Input("f: T" ) |
306 | .Input("o: T" ) |
307 | .Input("ci: T" ) |
308 | .Input("co: T" ) |
309 | .Input("h: T" ) |
310 | .Input("cs_grad: T" ) |
311 | .Input("h_grad: T" ) |
312 | .Output("x_grad: T" ) |
313 | .Output("cs_prev_grad: T" ) |
314 | .Output("h_prev_grad: T" ) |
315 | .Output("w_grad: T" ) |
316 | .Output("wci_grad: T" ) |
317 | .Output("wcf_grad: T" ) |
318 | .Output("wco_grad: T" ) |
319 | .Output("b_grad: T" ) |
320 | .Attr("use_peephole: bool" ) |
321 | .Attr("T: {half, float}" ) |
322 | .SetShapeFn([](InferenceContext* c) { |
323 | ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b; |
324 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); |
325 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev)); |
326 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev)); |
327 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w)); |
328 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci)); |
329 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco)); |
330 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf)); |
331 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b)); |
332 | |
333 | c->set_output(0, x); |
334 | c->set_output(1, cs_prev); |
335 | c->set_output(2, h_prev); |
336 | c->set_output(3, w); |
337 | c->set_output(4, wci); |
338 | c->set_output(5, wco); |
339 | c->set_output(6, wcf); |
340 | c->set_output(7, b); |
341 | |
342 | return OkStatus(); |
343 | }); |
344 | |
345 | } // end namespace tensorflow |
346 | |