1 | #include <gtest/gtest.h> |
2 | |
3 | #include <torch/torch.h> |
4 | |
5 | #include <test/cpp/api/support.h> |
6 | |
7 | using namespace torch::nn; |
8 | |
// Test fixture that seeds the RNG (via SeedingFixture) so every transformer
// test in this file produces deterministic results.
struct TransformerTest : torch::test::SeedingFixture {};
10 | |
11 | // a generic function to set constants for parameters so we have fixed result |
12 | // for deterministic test |
13 | template <typename Model> |
14 | void set_parameter_to_constants( |
15 | Model& model, |
16 | const torch::TensorOptions& tensor_options) { |
17 | torch::NoGradGuard guard; |
18 | for (auto& p : model->parameters()) { |
19 | auto sz = p.view(-1).size(0); |
20 | p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes()))); |
21 | } |
22 | } |
23 | |
24 | // a generic function to provide consistent encoder/decoder layer for all the |
25 | // transformer tests |
26 | template <typename T_LAYER, typename T_OPTIONS> |
27 | T_LAYER get_a_test_layer( |
28 | const torch::TensorOptions& tensor_options, |
29 | bool use_callable_activation) { |
30 | int64_t d_model = 4; |
31 | int64_t nhead = 2; |
32 | int64_t dim_feedforward = 16; |
33 | double dropout = 0.0; |
34 | |
35 | // activation is always ReLU here and it can be adjusted later depending on |
36 | // the usage |
37 | T_LAYER layer(T_OPTIONS(d_model, nhead) |
38 | .dim_feedforward(dim_feedforward) |
39 | .dropout(dropout)); |
40 | if (tensor_options.device() == torch::kCUDA) { |
41 | layer->to(torch::kCUDA); |
42 | } |
43 | if (use_callable_activation) { |
44 | layer.get()->options.activation( |
45 | [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); }); |
46 | } |
47 | |
48 | // set constant weights of the model |
49 | set_parameter_to_constants<T_LAYER>(layer, tensor_options); |
50 | |
51 | return layer; |
52 | } |
53 | |
54 | void transformer_encoder_layer_test_helper( |
55 | bool is_cuda, |
56 | bool use_callable_activation) { |
57 | // this is a deterministic test for TransformerEncoderLayer |
58 | torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; |
59 | torch::TensorOptions tensor_options = |
60 | torch::TensorOptions().dtype(torch::kFloat32).device(device); |
61 | |
62 | TransformerEncoderLayer model = |
63 | get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>( |
64 | tensor_options, use_callable_activation); |
65 | |
66 | // relu test case 1 |
67 | torch::Tensor encoder_input = |
68 | torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
69 | torch::Tensor result = model(encoder_input).detach(); |
70 | torch::Tensor ref_output = torch::tensor( |
71 | {{{2.258703, 0.127985, -0.697881, 0.170862}}}, tensor_options); |
72 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
73 | ASSERT_TRUE( |
74 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
75 | |
76 | // all 0 values are NOT masked. This should't mask anything |
77 | torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1; |
78 | result = model( |
79 | encoder_input, |
80 | /*src_mask=*/torch::Tensor{}, |
81 | /*src_key_padding_mask=*/mask) |
82 | .detach(); |
83 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
84 | ASSERT_TRUE( |
85 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
86 | |
87 | // all 1 values are masked. Since there is only 1 input embedding this will |
88 | // result in nan. |
89 | mask = torch::tensor({{1}}, tensor_options) == 1; |
90 | result = model( |
91 | encoder_input, |
92 | /*src_mask=*/torch::Tensor{}, |
93 | /*src_key_padding_mask=*/mask) |
94 | .detach(); |
95 | ASSERT_TRUE(torch::isnan(result).all().item().to<bool>()); |
96 | |
97 | // relu test case 2 |
98 | encoder_input = |
99 | torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); |
100 | result = model(encoder_input).detach(); |
101 | ref_output = torch::tensor( |
102 | {{{2.272644, 0.119035, -0.691669, 0.153486}}, |
103 | {{2.272644, 0.119035, -0.691669, 0.153486}}}, |
104 | tensor_options); |
105 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
106 | ASSERT_TRUE( |
107 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
108 | |
109 | // all 0 values are NOT masked |
110 | mask = torch::tensor({{0, 0}}, tensor_options) == 1; |
111 | result = model( |
112 | encoder_input, |
113 | /*src_mask=*/torch::Tensor{}, |
114 | /*src_key_padding_mask=*/mask) |
115 | .detach(); |
116 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
117 | ASSERT_TRUE( |
118 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
119 | |
120 | // mask with 1 and 0 |
121 | mask = torch::tensor({{1, 0}}, tensor_options) == 1; |
122 | result = model( |
123 | encoder_input, |
124 | /*src_mask=*/torch::Tensor{}, |
125 | /*src_key_padding_mask=*/mask) |
126 | .detach(); |
127 | ref_output = torch::tensor( |
128 | {{{2.301516, 0.092249, -0.679101, 0.103088}}, |
129 | {{2.301516, 0.092249, -0.679101, 0.103088}}}, |
130 | tensor_options); |
131 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
132 | ASSERT_TRUE( |
133 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
134 | |
135 | // relu test case 3 |
136 | encoder_input = torch::tensor( |
137 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
138 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
139 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
140 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
141 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
142 | tensor_options); |
143 | result = model(encoder_input).detach(); |
144 | ref_output = torch::tensor( |
145 | {{{2.428589, 0.020835, -0.602055, -0.085249}, |
146 | {2.427987, 0.021213, -0.602496, -0.084103}}, |
147 | {{2.424689, 0.019155, -0.604793, -0.085672}, |
148 | {2.413863, 0.022211, -0.612486, -0.072490}}, |
149 | {{2.433774, 0.021598, -0.598343, -0.087548}, |
150 | {2.425104, 0.019748, -0.604515, -0.084839}}, |
151 | {{2.436185, 0.022682, -0.596625, -0.087261}, |
152 | {2.433556, 0.021891, -0.598509, -0.086832}}, |
153 | {{2.416246, 0.017512, -0.610712, -0.082961}, |
154 | {2.422901, 0.024187, -0.606178, -0.074929}}}, |
155 | tensor_options); |
156 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
157 | ASSERT_TRUE( |
158 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
159 | |
160 | // all 0 values are NOT masked |
161 | mask = torch::zeros({2, 5}, tensor_options) == 1; |
162 | result = model( |
163 | encoder_input, |
164 | /*src_mask=*/torch::Tensor{}, |
165 | /*src_key_padding_mask=*/mask) |
166 | .detach(); |
167 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
168 | ASSERT_TRUE( |
169 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
170 | |
171 | // mask with 0s and 1s |
172 | mask[0][1] = 1; |
173 | mask[1][3] = 1; |
174 | mask[1][4] = 1; |
175 | result = model( |
176 | encoder_input, |
177 | /*src_mask=*/torch::Tensor{}, |
178 | /*src_key_padding_mask=*/mask) |
179 | .detach(); |
180 | ref_output = torch::tensor( |
181 | {{{2.429026, 0.020793, -0.601741, -0.085642}, |
182 | {2.428811, 0.021445, -0.601912, -0.084252}}, |
183 | {{2.425009, 0.019155, -0.604566, -0.085899}, |
184 | {2.415408, 0.02249, -0.611415, -0.073}}, |
185 | {{2.434199, 0.021682, -0.598039, -0.087699}, |
186 | {2.42598, 0.019941, -0.603896, -0.085091}}, |
187 | {{2.436457, 0.022736, -0.59643, -0.08736}, |
188 | {2.434021, 0.022093, -0.598179, -0.08679}}, |
189 | {{2.416531, 0.017498, -0.610513, -0.083181}, |
190 | {2.4242, 0.024653, -0.605266, -0.074959}}}, |
191 | tensor_options); |
192 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
193 | ASSERT_TRUE( |
194 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
195 | |
196 | // gelu test case 1 |
197 | model.get()->options.activation(torch::kGELU); |
198 | encoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
199 | result = model(encoder_input).detach(); |
200 | ref_output = torch::tensor( |
201 | {{{2.249815, 0.131006, -0.702199, 0.177868}}}, tensor_options); |
202 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
203 | ASSERT_TRUE( |
204 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
205 | |
206 | // gelu test case 2 |
207 | encoder_input = torch::tensor( |
208 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
209 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
210 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
211 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
212 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
213 | tensor_options); |
214 | result = model(encoder_input); |
215 | ref_output = torch::tensor( |
216 | {{{2.42163188, 0.03227153, -0.60714219, -0.05908082}, |
217 | {2.42151276, 0.03302179, -0.60722523, -0.05762651}}, |
218 | {{2.41926761, 0.02974034, -0.60879519, -0.0621269}, |
219 | {2.41626395, 0.03539356, -0.61087842, -0.04978623}}, |
220 | {{2.42382808, 0.03218872, -0.6055963, -0.06073591}, |
221 | {2.41983477, 0.03085259, -0.60840145, -0.06046414}}, |
222 | {{2.42500749, 0.03328855, -0.60476388, -0.0595334}, |
223 | {2.4237977, 0.03290575, -0.60561789, -0.05940082}}, |
224 | {{2.41383916, 0.02686345, -0.61256377, -0.06380707}, |
225 | {2.42000277, 0.03800944, -0.60824798, -0.04754947}}}, |
226 | tensor_options); |
227 | ASSERT_EQ(result.sizes(), ref_output.sizes()); |
228 | ASSERT_TRUE( |
229 | torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); |
230 | } |
231 | |
232 | TEST_F(TransformerTest, TransformerEncoderLayer) { |
233 | transformer_encoder_layer_test_helper( |
234 | /*is_cuda=*/false, /*use_callable_activation=*/false); |
235 | transformer_encoder_layer_test_helper( |
236 | /*is_cuda=*/false, /*use_callable_activation=*/true); |
237 | } |
238 | |
239 | TEST_F(TransformerTest, TransformerEncoderLayer_CUDA) { |
240 | transformer_encoder_layer_test_helper( |
241 | /*is_cuda=*/true, /*use_callable_activation=*/false); |
242 | transformer_encoder_layer_test_helper( |
243 | /*is_cuda=*/true, /*use_callable_activation=*/true); |
244 | } |
245 | |
246 | void transformer_decoder_layer_test_helper( |
247 | bool is_cuda, |
248 | bool use_callable_activation) { |
249 | torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; |
250 | torch::TensorOptions tensor_options = |
251 | torch::TensorOptions().dtype(torch::kFloat32).device(device); |
252 | |
253 | TransformerDecoderLayer model = |
254 | get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>( |
255 | tensor_options, use_callable_activation); |
256 | |
257 | // deterministic input |
258 | torch::Tensor decoder_input = |
259 | torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
260 | torch::Tensor memory_input = |
261 | torch::tensor({{{60, 70, 80, 90}}}, tensor_options); |
262 | torch::Tensor result = model(decoder_input, memory_input).detach(); |
263 | torch::Tensor ref_output = torch::tensor( |
264 | {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options); |
265 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
266 | ASSERT_TRUE(torch::allclose( |
267 | result, |
268 | ref_output, |
269 | 1e-7, |
270 | 1e-5, |
271 | /*equal_nan=*/true)); |
272 | |
273 | // deterministic input |
274 | decoder_input = |
275 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
276 | memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); |
277 | result = model(decoder_input, memory_input).detach(); |
278 | ref_output = torch::tensor( |
279 | {{{2.422245, 0.051716, -0.606338, -0.024756}}, |
280 | {{2.422245, 0.051716, -0.606338, -0.024756}}}, |
281 | tensor_options); |
282 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
283 | ASSERT_TRUE(torch::allclose( |
284 | result, |
285 | ref_output, |
286 | 1e-7, |
287 | 1e-5, |
288 | /*equal_nan=*/true)); |
289 | |
290 | // deterministic input |
291 | decoder_input = |
292 | torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); |
293 | memory_input = |
294 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
295 | result = model(decoder_input, memory_input).detach(); |
296 | ref_output = torch::tensor( |
297 | {{{2.343536, 0.085561, -0.654954, 0.074991}}, |
298 | {{2.343536, 0.085561, -0.654954, 0.074991}}}, |
299 | tensor_options); |
300 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
301 | ASSERT_TRUE(torch::allclose( |
302 | result, |
303 | ref_output, |
304 | 1e-7, |
305 | 1e-5, |
306 | /*equal_nan=*/true)); |
307 | |
308 | // deterministic input |
309 | decoder_input = torch::tensor( |
310 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
311 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
312 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
313 | tensor_options); |
314 | memory_input = torch::tensor( |
315 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
316 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
317 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
318 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
319 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
320 | tensor_options); |
321 | result = model(decoder_input, memory_input).detach(); |
322 | ref_output = torch::tensor( |
323 | {{{2.430065, 0.027862, -0.601136, -0.073096}, |
324 | {2.431935, 0.028907, -0.599809, -0.072488}}, |
325 | {{2.428457, 0.027053, -0.602275, -0.073462}, |
326 | {2.431970, 0.029387, -0.599789, -0.071621}}, |
327 | {{2.431934, 0.028196, -0.599802, -0.073809}, |
328 | {2.432306, 0.028858, -0.599542, -0.072846}}}, |
329 | tensor_options); |
330 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
331 | ASSERT_TRUE(torch::allclose( |
332 | result, |
333 | ref_output, |
334 | 1e-7, |
335 | 1e-5, |
336 | /*equal_nan=*/true)); |
337 | |
338 | // key_padding_mask |
339 | torch::Tensor t_mask = {}; |
340 | torch::Tensor m_mask = {}; |
341 | torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1; |
342 | result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) |
343 | .detach(); |
344 | ref_output = torch::tensor( |
345 | {{{2.430065, 0.027862, -0.601136, -0.073096}, |
346 | {2.431935, 0.028907, -0.599809, -0.072488}}, |
347 | {{2.428457, 0.027053, -0.602275, -0.073462}, |
348 | {2.431970, 0.029387, -0.599789, -0.071621}}, |
349 | {{2.431934, 0.028196, -0.599802, -0.073809}, |
350 | {2.432306, 0.028858, -0.599542, -0.072846}}}, |
351 | tensor_options); |
352 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
353 | ASSERT_TRUE(torch::allclose( |
354 | result, |
355 | ref_output, |
356 | 1e-7, |
357 | 1e-5, |
358 | /*equal_nan=*/true)); |
359 | |
360 | // key_padding_mask |
361 | key_padding_mask[0][2] = 1; |
362 | key_padding_mask[1][1] = 1; |
363 | key_padding_mask[1][2] = 1; |
364 | result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) |
365 | .detach(); |
366 | ref_output = torch::tensor( |
367 | {{{2.430025, 0.027643, -0.601164, -0.073476}, |
368 | {2.4323, 0.029375, -0.599553, -0.071881}}, |
369 | {{2.428523, 0.026838, -0.602226, -0.07391}, |
370 | {2.432634, 0.029842, -0.599318, -0.071253}}, |
371 | {{2.432278, 0.028152, -0.599555, -0.074139}, |
372 | {2.432659, 0.029244, -0.599294, -0.072382}}}, |
373 | tensor_options); |
374 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
375 | ASSERT_TRUE(torch::allclose( |
376 | result, |
377 | ref_output, |
378 | 1e-7, |
379 | 1e-5, |
380 | /*equal_nan=*/true)); |
381 | |
382 | // memory_key_padding_mask |
383 | torch::Tensor t_key_padding_mask = {}; |
384 | key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1; |
385 | result = model( |
386 | decoder_input, |
387 | memory_input, |
388 | t_mask, |
389 | m_mask, |
390 | t_key_padding_mask, |
391 | key_padding_mask) |
392 | .detach(); |
393 | ref_output = torch::tensor( |
394 | {{{2.430065, 0.027862, -0.601136, -0.073096}, |
395 | {2.431935, 0.028907, -0.599809, -0.072488}}, |
396 | {{2.428457, 0.027053, -0.602275, -0.073462}, |
397 | {2.431970, 0.029387, -0.599789, -0.071621}}, |
398 | {{2.431934, 0.028196, -0.599802, -0.073809}, |
399 | {2.432306, 0.028858, -0.599542, -0.072846}}}, |
400 | tensor_options); |
401 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
402 | ASSERT_TRUE(torch::allclose( |
403 | result, |
404 | ref_output, |
405 | 1e-7, |
406 | 1e-5, |
407 | /*equal_nan=*/true)); |
408 | |
409 | // memory_key_padding_mask |
410 | key_padding_mask[0][4] = 1; |
411 | key_padding_mask[1][3] = 1; |
412 | key_padding_mask[1][4] = 1; |
413 | result = model( |
414 | decoder_input, |
415 | memory_input, |
416 | t_mask, |
417 | m_mask, |
418 | t_key_padding_mask, |
419 | key_padding_mask) |
420 | .detach(); |
421 | ref_output = torch::tensor( |
422 | {{{2.429757, 0.027358, -0.601351, -0.073816}, |
423 | {2.432692, 0.028583, -0.599263, -0.073634}}, |
424 | {{2.428247, 0.02662, -0.602419, -0.074123}, |
425 | {2.432657, 0.029055, -0.599293, -0.072732}}, |
426 | {{2.431515, 0.027687, -0.600096, -0.074459}, |
427 | {2.433075, 0.028543, -0.598987, -0.073985}}}, |
428 | tensor_options); |
429 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
430 | ASSERT_TRUE(torch::allclose( |
431 | result, |
432 | ref_output, |
433 | 1e-7, |
434 | 1e-5, |
435 | /*equal_nan=*/true)); |
436 | } |
437 | |
438 | TEST_F(TransformerTest, TransformerDecoderLayer) { |
439 | transformer_decoder_layer_test_helper( |
440 | /*is_cuda=*/false, /*use_callable_activation=*/false); |
441 | transformer_decoder_layer_test_helper( |
442 | /*is_cuda=*/false, /*use_callable_activation=*/true); |
443 | } |
444 | |
445 | TEST_F(TransformerTest, TransformerDecoderLayer_CUDA) { |
446 | transformer_decoder_layer_test_helper( |
447 | /*is_cuda=*/true, /*use_callable_activation=*/false); |
448 | transformer_decoder_layer_test_helper( |
449 | /*is_cuda=*/true, /*use_callable_activation=*/true); |
450 | } |
451 | |
452 | void transformer_decoder_layer_test_helper_gelu( |
453 | bool is_cuda, |
454 | bool use_callable_activation) { |
455 | torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; |
456 | torch::TensorOptions tensor_options = |
457 | torch::TensorOptions().dtype(torch::kFloat32).device(device); |
458 | |
459 | TransformerDecoderLayer model = |
460 | get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>( |
461 | tensor_options, use_callable_activation); |
462 | if (use_callable_activation) { |
463 | model.get()->options.activation( |
464 | [&](const torch::Tensor& t) { return torch::nn::functional::gelu(t); }); |
465 | } else { |
466 | model.get()->options.activation(torch::kGELU); |
467 | } |
468 | |
469 | // deterministic input |
470 | torch::Tensor decoder_input = |
471 | torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
472 | torch::Tensor memory_input = |
473 | torch::tensor({{{60, 70, 80, 90}}}, tensor_options); |
474 | torch::Tensor result = model(decoder_input, memory_input).detach(); |
475 | torch::Tensor ref_output = torch::tensor( |
476 | {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options); |
477 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
478 | ASSERT_TRUE(torch::allclose( |
479 | result, |
480 | ref_output, |
481 | 1e-7, |
482 | 1e-5, |
483 | /*equal_nan=*/true)); |
484 | |
485 | // deterministic input |
486 | decoder_input = |
487 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
488 | memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); |
489 | result = model(decoder_input, memory_input).detach(); |
490 | ref_output = torch::tensor( |
491 | {{{2.415448, 0.054389, -0.610932, -0.0156613}}, |
492 | {{2.415448, 0.054389, -0.610932, -0.0156613}}}, |
493 | tensor_options); |
494 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
495 | ASSERT_TRUE(torch::allclose( |
496 | result, |
497 | ref_output, |
498 | 1e-7, |
499 | 1e-5, |
500 | /*equal_nan=*/true)); |
501 | |
502 | // deterministic input |
503 | decoder_input = |
504 | torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); |
505 | memory_input = |
506 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
507 | result = model(decoder_input, memory_input).detach(); |
508 | ref_output = torch::tensor( |
509 | {{{2.338531, 0.087709, -0.65776, 0.080646}}, |
510 | {{2.338531, 0.087709, -0.65776, 0.080646}}}, |
511 | tensor_options); |
512 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
513 | ASSERT_TRUE(torch::allclose( |
514 | result, |
515 | ref_output, |
516 | 1e-7, |
517 | 1e-5, |
518 | /*equal_nan=*/true)); |
519 | |
520 | // deterministic input |
521 | decoder_input = torch::tensor( |
522 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
523 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
524 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
525 | tensor_options); |
526 | memory_input = torch::tensor( |
527 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
528 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
529 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
530 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
531 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
532 | tensor_options); |
533 | result = model(decoder_input, memory_input).detach(); |
534 | ref_output = torch::tensor( |
535 | {{{2.42049104, 0.03443088, -0.60793706, -0.05436271}, |
536 | {2.42210631, 0.03546578, -0.60679895, -0.05357488}}, |
537 | {{2.41907674, 0.0336104, -0.60892977, -0.05490462}, |
538 | {2.42216881, 0.03586554, -0.6067524, -0.05289126}}, |
539 | {{2.42205716, 0.03488046, -0.60683681, -0.05460596}, |
540 | {2.42240309, 0.0354595, -0.60659063, -0.05378816}}}, |
541 | tensor_options); |
542 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
543 | ASSERT_TRUE(torch::allclose( |
544 | result, |
545 | ref_output, |
546 | 1e-7, |
547 | 1e-5, |
548 | /*equal_nan=*/true)); |
549 | } |
550 | |
551 | TEST_F(TransformerTest, TransformerDecoderLayer_gelu) { |
552 | transformer_decoder_layer_test_helper_gelu( |
553 | /*is_cuda=*/false, /*use_callable_activation=*/false); |
554 | transformer_decoder_layer_test_helper_gelu( |
555 | /*is_cuda=*/false, /*use_callable_activation=*/true); |
556 | } |
557 | |
558 | TEST_F(TransformerTest, TransformerDecoderLayer_gelu_CUDA) { |
559 | transformer_decoder_layer_test_helper_gelu( |
560 | /*is_cuda=*/true, /*use_callable_activation=*/false); |
561 | transformer_decoder_layer_test_helper_gelu( |
562 | /*is_cuda=*/true, /*use_callable_activation=*/true); |
563 | } |
564 | |
// Deterministic test for TransformerEncoder: stacks of 1/2/6 identical
// constant-weight encoder layers, with and without a final LayerNorm.
// NOTE(review): `encoder_layer` (and later `norm`) are reused to build
// several models — presumably TransformerEncoder clones the modules it is
// given; confirm against the implementation.
void transformer_encoder_test_helper(
    bool is_cuda,
    bool use_callable_activation) {
  // this is a deterministic test for TransformerEncoder
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
  torch::TensorOptions tensor_options =
      torch::TensorOptions().dtype(torch::kFloat32).device(device);

  TransformerEncoderLayer encoder_layer =
      get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
          tensor_options, use_callable_activation);

  // test case 1: a single layer, no norm
  TransformerEncoder model(TransformerEncoderOptions(encoder_layer, 1));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }

  torch::Tensor encoder_input = torch::tensor(
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
      tensor_options);
  torch::Tensor result = model(encoder_input).detach();
  torch::Tensor ref_output = torch::tensor(
      {{{2.428589, 0.020835, -0.602055, -0.085249},
        {2.427987, 0.021213, -0.602496, -0.084103}},
       {{2.424689, 0.019155, -0.604793, -0.085672},
        {2.413863, 0.022211, -0.612486, -0.072490}},
       {{2.433774, 0.021598, -0.598343, -0.087548},
        {2.425104, 0.019748, -0.604515, -0.084839}},
       {{2.436185, 0.022682, -0.596625, -0.087261},
        {2.433556, 0.021891, -0.598509, -0.086832}},
       {{2.416246, 0.017512, -0.610712, -0.082961},
        {2.422901, 0.024187, -0.606178, -0.074929}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // all 0 values are NOT masked
  torch::Tensor mask = torch::zeros({2, 5}, tensor_options) == 1;
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // mask with 0s and 1s
  mask[0][1] = 1;
  mask[1][3] = 1;
  mask[1][4] = 1;
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{2.429026, 0.020793, -0.601741, -0.085642},
        {2.428811, 0.021445, -0.601912, -0.084252}},
       {{2.425009, 0.019155, -0.604566, -0.085899},
        {2.415408, 0.02249, -0.611415, -0.073}},
       {{2.434199, 0.021682, -0.598039, -0.087699},
        {2.42598, 0.019941, -0.603896, -0.085091}},
       {{2.436457, 0.022736, -0.59643, -0.08736},
        {2.434021, 0.022093, -0.598179, -0.08679}},
       {{2.416531, 0.017498, -0.610513, -0.083181},
        {2.4242, 0.024653, -0.605266, -0.074959}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // test case 2, multiple layers no norm
  model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 2));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{2.419051, 0.017446, -0.608738, -0.085003},
        {2.419102, 0.017452, -0.608703, -0.085026}},
       {{2.419043, 0.017445, -0.608744, -0.084999},
        {2.419052, 0.017446, -0.608738, -0.085004}},
       {{2.419067, 0.017448, -0.608727, -0.085010},
        {2.419098, 0.017452, -0.608706, -0.085024}},
       {{2.419072, 0.017449, -0.608724, -0.085012},
        {2.419119, 0.017455, -0.608691, -0.085034}},
       {{2.419019, 0.017442, -0.608761, -0.084989},
        {2.419075, 0.017449, -0.608722, -0.085014}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // with 6 identical layers the output converges to (nearly) the same value
  // at every position
  model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 6));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // test case 3, multiple layers with norm
  LayerNorm norm(LayerNormOptions({encoder_layer.get()->options.d_model()}));
  model = TransformerEncoder(
      TransformerEncoderOptions(encoder_layer, 2).norm(AnyModule(norm)));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{1.695949, -0.357635, -0.893077, -0.445238},
        {1.695955, -0.357639, -0.893050, -0.445266}},
       {{1.695948, -0.357634, -0.893082, -0.445233},
        {1.695950, -0.357635, -0.893077, -0.445238}},
       {{1.695951, -0.357636, -0.893069, -0.445246},
        {1.695955, -0.357639, -0.893052, -0.445264}},
       {{1.695952, -0.357636, -0.893066, -0.445249},
        {1.695957, -0.357641, -0.893041, -0.445276}},
       {{1.695946, -0.357632, -0.893095, -0.445220},
        {1.695952, -0.357637, -0.893065, -0.445251}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // 6 layers with norm: converged output
  model = TransformerEncoder(
      TransformerEncoderOptions(encoder_layer, 6).norm(AnyModule(norm)));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
}
747 | |
748 | TEST_F(TransformerTest, TransformerEncoder) { |
749 | transformer_encoder_test_helper( |
750 | /*is_cuda=*/false, /*use_callable_activation=*/false); |
751 | transformer_encoder_test_helper( |
752 | /*is_cuda=*/false, /*use_callable_activation=*/true); |
753 | } |
754 | |
755 | TEST_F(TransformerTest, TransformerEncoder_CUDA) { |
756 | transformer_encoder_test_helper( |
757 | /*is_cuda=*/true, /*use_callable_activation=*/false); |
758 | transformer_encoder_test_helper( |
759 | /*is_cuda=*/true, /*use_callable_activation=*/true); |
760 | } |
761 | |
// The pretty-printed module tree of a TransformerEncoderLayer(4, 2) must
// match this exact layout; the defaults are visible in the string
// (dim_feedforward=2048 in linear1/linear2, dropout p=0.1).
TEST_F(TransformerTest, PrettyPrintTransformerEncoderLayer) {
  ASSERT_EQ(
      c10::str(TransformerEncoderLayer(4, 2)),
      "torch::nn::TransformerEncoderLayerImpl(\n"
      "  (self_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      ")");
}
778 | |
// Checks the pretty-printed hierarchy of a TransformerEncoder with two
// identical layers and a final LayerNorm attached via the norm() option.
TEST_F(TransformerTest, PrettyPrintTransformerEncoder) {
  LayerNorm norm = LayerNorm(LayerNormOptions({4}));
  TransformerEncoderOptions options(
      TransformerEncoderOptions(TransformerEncoderLayerOptions(4, 2), 2)
          .norm(AnyModule(norm)));
  ASSERT_EQ(
      c10::str(TransformerEncoder(options)),
      "torch::nn::TransformerEncoderImpl(\n"
      "  (layers): torch::nn::ModuleList(\n"
      "    (0): torch::nn::TransformerEncoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "    (1): torch::nn::TransformerEncoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "  )\n"
      "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      ")");
}
816 | |
// Checks the pretty-printed module hierarchy of a default-constructed
// TransformerDecoderLayer(d_model=4, nhead=2); unlike the encoder layer it
// also contains the cross-attention block (multihead_attn) plus norm3 and
// dropout3.
TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) {
  ASSERT_EQ(
      c10::str(TransformerDecoderLayer(4, 2)),
      "torch::nn::TransformerDecoderLayerImpl(\n"
      "  (self_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (multihead_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      ")");
}
838 | |
839 | void transformer_decoder_test_helper( |
840 | bool is_cuda, |
841 | bool use_callable_activation) { |
842 | // this is a deterministic test for TransformerDecoder |
843 | torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; |
844 | torch::TensorOptions tensor_options = |
845 | torch::TensorOptions().dtype(torch::kFloat32).device(device); |
846 | |
847 | TransformerDecoderLayer decoder_layer = |
848 | get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>( |
849 | tensor_options, use_callable_activation); |
850 | |
851 | TransformerDecoder model(TransformerDecoderOptions(decoder_layer, 1)); |
852 | if (is_cuda) { |
853 | model->to(torch::kCUDA); |
854 | } |
855 | |
856 | torch::Tensor decoder_input = |
857 | torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
858 | torch::Tensor memory_input = |
859 | torch::tensor({{{60, 70, 80, 90}}}, tensor_options); |
860 | torch::Tensor result = model(decoder_input, memory_input).detach(); |
861 | torch::Tensor ref_output = torch::tensor( |
862 | {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options); |
863 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
864 | ASSERT_TRUE(torch::allclose( |
865 | result, |
866 | ref_output, |
867 | 1e-7, |
868 | 1e-5, |
869 | /*equal_nan=*/true)); |
870 | |
871 | // deterministic input |
872 | decoder_input = |
873 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
874 | memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); |
875 | result = model(decoder_input, memory_input).detach(); |
876 | ref_output = torch::tensor( |
877 | {{{2.422245, 0.051716, -0.606338, -0.024756}}, |
878 | {{2.422245, 0.051716, -0.606338, -0.024756}}}, |
879 | tensor_options); |
880 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
881 | ASSERT_TRUE(torch::allclose( |
882 | result, |
883 | ref_output, |
884 | 1e-7, |
885 | 1e-5, |
886 | /*equal_nan=*/true)); |
887 | |
888 | // deterministic input |
889 | decoder_input = |
890 | torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); |
891 | memory_input = |
892 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
893 | result = model(decoder_input, memory_input).detach(); |
894 | ref_output = torch::tensor( |
895 | {{{2.343536, 0.085561, -0.654954, 0.074991}}, |
896 | {{2.343536, 0.085561, -0.654954, 0.074991}}}, |
897 | tensor_options); |
898 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
899 | ASSERT_TRUE(torch::allclose( |
900 | result, |
901 | ref_output, |
902 | 1e-7, |
903 | 1e-5, |
904 | /*equal_nan=*/true)); |
905 | |
906 | // deterministic input |
907 | decoder_input = torch::tensor( |
908 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
909 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
910 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
911 | tensor_options); |
912 | memory_input = torch::tensor( |
913 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
914 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
915 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
916 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
917 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
918 | tensor_options); |
919 | result = model(decoder_input, memory_input).detach(); |
920 | ref_output = torch::tensor( |
921 | {{{2.430065, 0.027862, -0.601136, -0.073096}, |
922 | {2.431935, 0.028907, -0.599809, -0.072488}}, |
923 | {{2.428457, 0.027053, -0.602275, -0.073462}, |
924 | {2.431970, 0.029387, -0.599789, -0.071621}}, |
925 | {{2.431934, 0.028196, -0.599802, -0.073809}, |
926 | {2.432306, 0.028858, -0.599542, -0.072846}}}, |
927 | tensor_options); |
928 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
929 | ASSERT_TRUE(torch::allclose( |
930 | result, |
931 | ref_output, |
932 | 1e-7, |
933 | 1e-5, |
934 | /*equal_nan=*/true)); |
935 | |
936 | // key_padding_mask |
937 | torch::Tensor t_mask = {}; |
938 | torch::Tensor m_mask = {}; |
939 | torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1; |
940 | result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) |
941 | .detach(); |
942 | ref_output = torch::tensor( |
943 | {{{2.430065, 0.027862, -0.601136, -0.073096}, |
944 | {2.431935, 0.028907, -0.599809, -0.072488}}, |
945 | {{2.428457, 0.027053, -0.602275, -0.073462}, |
946 | {2.431970, 0.029387, -0.599789, -0.071621}}, |
947 | {{2.431934, 0.028196, -0.599802, -0.073809}, |
948 | {2.432306, 0.028858, -0.599542, -0.072846}}}, |
949 | tensor_options); |
950 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
951 | ASSERT_TRUE(torch::allclose( |
952 | result, |
953 | ref_output, |
954 | 1e-7, |
955 | 1e-5, |
956 | /*equal_nan=*/true)); |
957 | |
958 | // key_padding_mask |
959 | key_padding_mask[0][2] = 1; |
960 | key_padding_mask[1][1] = 1; |
961 | key_padding_mask[1][2] = 1; |
962 | result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask) |
963 | .detach(); |
964 | ref_output = torch::tensor( |
965 | {{{2.430025, 0.027643, -0.601164, -0.073476}, |
966 | {2.4323, 0.029375, -0.599553, -0.071881}}, |
967 | {{2.428523, 0.026838, -0.602226, -0.07391}, |
968 | {2.432634, 0.029842, -0.599318, -0.071253}}, |
969 | {{2.432278, 0.028152, -0.599555, -0.074139}, |
970 | {2.432659, 0.029244, -0.599294, -0.072382}}}, |
971 | tensor_options); |
972 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
973 | ASSERT_TRUE(torch::allclose( |
974 | result, |
975 | ref_output, |
976 | 1e-7, |
977 | 1e-5, |
978 | /*equal_nan=*/true)); |
979 | |
980 | // memory_key_padding_mask |
981 | torch::Tensor t_key_padding_mask = {}; |
982 | key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1; |
983 | result = model( |
984 | decoder_input, |
985 | memory_input, |
986 | t_mask, |
987 | m_mask, |
988 | t_key_padding_mask, |
989 | key_padding_mask) |
990 | .detach(); |
991 | ref_output = torch::tensor( |
992 | {{{2.430065, 0.027862, -0.601136, -0.073096}, |
993 | {2.431935, 0.028907, -0.599809, -0.072488}}, |
994 | {{2.428457, 0.027053, -0.602275, -0.073462}, |
995 | {2.431970, 0.029387, -0.599789, -0.071621}}, |
996 | {{2.431934, 0.028196, -0.599802, -0.073809}, |
997 | {2.432306, 0.028858, -0.599542, -0.072846}}}, |
998 | tensor_options); |
999 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1000 | ASSERT_TRUE(torch::allclose( |
1001 | result, |
1002 | ref_output, |
1003 | 1e-7, |
1004 | 1e-5, |
1005 | /*equal_nan=*/true)); |
1006 | |
1007 | // memory_key_padding_mask |
1008 | key_padding_mask[0][4] = 1; |
1009 | key_padding_mask[1][3] = 1; |
1010 | key_padding_mask[1][4] = 1; |
1011 | result = model( |
1012 | decoder_input, |
1013 | memory_input, |
1014 | t_mask, |
1015 | m_mask, |
1016 | t_key_padding_mask, |
1017 | key_padding_mask) |
1018 | .detach(); |
1019 | ref_output = torch::tensor( |
1020 | {{{2.429757, 0.027358, -0.601351, -0.073816}, |
1021 | {2.432692, 0.028583, -0.599263, -0.073634}}, |
1022 | {{2.428247, 0.02662, -0.602419, -0.074123}, |
1023 | {2.432657, 0.029055, -0.599293, -0.072732}}, |
1024 | {{2.431515, 0.027687, -0.600096, -0.074459}, |
1025 | {2.433075, 0.028543, -0.598987, -0.073985}}}, |
1026 | tensor_options); |
1027 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1028 | ASSERT_TRUE(torch::allclose( |
1029 | result, |
1030 | ref_output, |
1031 | 1e-7, |
1032 | 1e-5, |
1033 | /*equal_nan=*/true)); |
1034 | |
1035 | // multiple layers no norm |
1036 | model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 2)); |
1037 | if (is_cuda) { |
1038 | model->to(torch::kCUDA); |
1039 | } |
1040 | |
1041 | decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
1042 | memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); |
1043 | result = model(decoder_input, memory_input).detach(); |
1044 | ref_output = torch::tensor( |
1045 | {{{2.31316, 0.0950293, -0.671995, 0.102802}}}, tensor_options); |
1046 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1047 | ASSERT_TRUE(torch::allclose( |
1048 | result, |
1049 | ref_output, |
1050 | 1e-7, |
1051 | 1e-5, |
1052 | /*equal_nan=*/true)); |
1053 | |
1054 | // multiple layers no norm |
1055 | model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)); |
1056 | if (is_cuda) { |
1057 | model->to(torch::kCUDA); |
1058 | } |
1059 | // deterministic input |
1060 | decoder_input = torch::tensor( |
1061 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
1062 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
1063 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
1064 | tensor_options); |
1065 | memory_input = torch::tensor( |
1066 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
1067 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
1068 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
1069 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
1070 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
1071 | tensor_options); |
1072 | result = model(decoder_input, memory_input).detach(); |
1073 | ref_output = torch::tensor( |
1074 | {{{2.42794, 0.026164, -0.60263, -0.0747591}, |
1075 | {2.43113, 0.0279516, -0.600376, -0.0736896}}, |
1076 | {{2.42794, 0.026164, -0.60263, -0.0747591}, |
1077 | {2.43113, 0.0279516, -0.600376, -0.0736896}}, |
1078 | {{2.42794, 0.026164, -0.60263, -0.0747591}, |
1079 | {2.43113, 0.0279516, -0.600376, -0.0736896}}}, |
1080 | tensor_options); |
1081 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1082 | ASSERT_TRUE(torch::allclose( |
1083 | result, |
1084 | ref_output, |
1085 | 1e-7, |
1086 | 1e-5, |
1087 | /*equal_nan=*/true)); |
1088 | |
1089 | // multiple layers with norm |
1090 | LayerNorm norm(LayerNormOptions({decoder_layer.get()->options.d_model()})); |
1091 | model = TransformerDecoder( |
1092 | TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm))); |
1093 | if (is_cuda) { |
1094 | model->to(torch::kCUDA); |
1095 | } |
1096 | |
1097 | decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
1098 | memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); |
1099 | result = model(decoder_input, memory_input).detach(); |
1100 | ref_output = torch::tensor( |
1101 | {{{1.66166, -0.326986, -1.01466, -0.320017}}}, tensor_options); |
1102 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1103 | ASSERT_TRUE(torch::allclose( |
1104 | result, |
1105 | ref_output, |
1106 | 1e-7, |
1107 | 1e-5, |
1108 | /*equal_nan=*/true)); |
1109 | |
1110 | // multiple layers with norm |
1111 | model = TransformerDecoder( |
1112 | TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm))); |
1113 | if (is_cuda) { |
1114 | model->to(torch::kCUDA); |
1115 | } |
1116 | // deterministic input |
1117 | decoder_input = torch::tensor( |
1118 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
1119 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
1120 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
1121 | tensor_options); |
1122 | memory_input = torch::tensor( |
1123 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
1124 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
1125 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
1126 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
1127 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
1128 | tensor_options); |
1129 | result = model(decoder_input, memory_input).detach(); |
1130 | ref_output = torch::tensor( |
1131 | {{{1.69559, -0.357291, -0.894741, -0.443553}, |
1132 | {1.69571, -0.357363, -0.894154, -0.444196}}, |
1133 | {{1.69559, -0.357291, -0.894741, -0.443553}, |
1134 | {1.69571, -0.357363, -0.894154, -0.444196}}, |
1135 | {{1.69559, -0.357291, -0.894741, -0.443553}, |
1136 | {1.69571, -0.357363, -0.894154, -0.444196}}}, |
1137 | tensor_options); |
1138 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1139 | ASSERT_TRUE(torch::allclose( |
1140 | result, |
1141 | ref_output, |
1142 | 1e-7, |
1143 | 1e-5, |
1144 | /*equal_nan=*/true)); |
1145 | |
1146 | // gelu activation test cases |
1147 | decoder_layer.get()->options.activation(torch::kGELU); |
1148 | model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 1)); |
1149 | if (is_cuda) { |
1150 | model->to(torch::kCUDA); |
1151 | } |
1152 | |
1153 | // deterministic input |
1154 | decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); |
1155 | memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); |
1156 | result = model(decoder_input, memory_input).detach(); |
1157 | ref_output = torch::tensor( |
1158 | {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options); |
1159 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1160 | ASSERT_TRUE(torch::allclose( |
1161 | result, |
1162 | ref_output, |
1163 | 1e-7, |
1164 | 1e-5, |
1165 | /*equal_nan=*/true)); |
1166 | |
1167 | // deterministic input |
1168 | decoder_input = |
1169 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
1170 | memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); |
1171 | result = model(decoder_input, memory_input).detach(); |
1172 | ref_output = torch::tensor( |
1173 | {{{2.415448, 0.054389, -0.610932, -0.0156613}}, |
1174 | {{2.415448, 0.054389, -0.610932, -0.0156613}}}, |
1175 | tensor_options); |
1176 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1177 | ASSERT_TRUE(torch::allclose( |
1178 | result, |
1179 | ref_output, |
1180 | 1e-7, |
1181 | 1e-5, |
1182 | /*equal_nan=*/true)); |
1183 | |
1184 | // deterministic input |
1185 | decoder_input = |
1186 | torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options); |
1187 | memory_input = |
1188 | torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options); |
1189 | result = model(decoder_input, memory_input).detach(); |
1190 | ref_output = torch::tensor( |
1191 | {{{2.338531, 0.087709, -0.65776, 0.080646}}, |
1192 | {{2.338531, 0.087709, -0.65776, 0.080646}}}, |
1193 | tensor_options); |
1194 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1195 | ASSERT_TRUE(torch::allclose( |
1196 | result, |
1197 | ref_output, |
1198 | 1e-7, |
1199 | 1e-5, |
1200 | /*equal_nan=*/true)); |
1201 | |
1202 | // deterministic input |
1203 | decoder_input = torch::tensor( |
1204 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
1205 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
1206 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
1207 | tensor_options); |
1208 | memory_input = torch::tensor( |
1209 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
1210 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
1211 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
1212 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
1213 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
1214 | tensor_options); |
1215 | result = model(decoder_input, memory_input).detach(); |
1216 | ref_output = torch::tensor( |
1217 | {{{2.42049104, 0.03443088, -0.60793706, -0.05436271}, |
1218 | {2.42210631, 0.03546578, -0.60679895, -0.05357488}}, |
1219 | {{2.41907674, 0.0336104, -0.60892977, -0.05490462}, |
1220 | {2.42216881, 0.03586554, -0.6067524, -0.05289126}}, |
1221 | {{2.42205716, 0.03488046, -0.60683681, -0.05460596}, |
1222 | {2.42240309, 0.0354595, -0.60659063, -0.05378816}}}, |
1223 | tensor_options); |
1224 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1225 | ASSERT_TRUE(torch::allclose( |
1226 | result, |
1227 | ref_output, |
1228 | 1e-7, |
1229 | 1e-5, |
1230 | /*equal_nan=*/true)); |
1231 | |
1232 | // Multiple layers no norm |
1233 | model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)); |
1234 | if (is_cuda) { |
1235 | model->to(torch::kCUDA); |
1236 | } |
1237 | decoder_input = torch::tensor( |
1238 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
1239 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
1240 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
1241 | tensor_options); |
1242 | memory_input = torch::tensor( |
1243 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
1244 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
1245 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
1246 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
1247 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
1248 | tensor_options); |
1249 | result = model(decoder_input, memory_input).detach(); |
1250 | ref_output = torch::tensor( |
1251 | {{{2.41859, 0.0328114, -0.609269, -0.0560386}, |
1252 | {2.42138, 0.034598, -0.607316, -0.0546574}}, |
1253 | {{2.41859, 0.0328114, -0.609269, -0.0560386}, |
1254 | {2.42138, 0.034598, -0.607316, -0.0546574}}, |
1255 | {{2.41859, 0.0328114, -0.609269, -0.0560386}, |
1256 | {2.42138, 0.034598, -0.607316, -0.0546574}}}, |
1257 | tensor_options); |
1258 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1259 | ASSERT_TRUE(torch::allclose( |
1260 | result, |
1261 | ref_output, |
1262 | 1e-7, |
1263 | 1e-5, |
1264 | /*equal_nan=*/true)); |
1265 | |
1266 | // Multiple layers with norm |
1267 | norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()})); |
1268 | model = TransformerDecoder( |
1269 | TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm))); |
1270 | if (is_cuda) { |
1271 | model->to(torch::kCUDA); |
1272 | } |
1273 | |
1274 | decoder_input = torch::tensor( |
1275 | {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}}, |
1276 | {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}}, |
1277 | {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}}, |
1278 | tensor_options); |
1279 | memory_input = torch::tensor( |
1280 | {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}}, |
1281 | {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}}, |
1282 | {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}}, |
1283 | {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}}, |
1284 | {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}}, |
1285 | tensor_options); |
1286 | result = model(decoder_input, memory_input).detach(); |
1287 | ref_output = torch::tensor( |
1288 | {{{1.69298, -0.355163, -0.906375, -0.431439}, |
1289 | {1.69305, -0.355195, -0.906062, -0.431791}}, |
1290 | {{1.69298, -0.355163, -0.906375, -0.431439}, |
1291 | {1.69305, -0.355195, -0.906062, -0.431791}}, |
1292 | {{1.69298, -0.355163, -0.906375, -0.431439}, |
1293 | {1.69305, -0.355195, -0.906062, -0.431791}}}, |
1294 | tensor_options); |
1295 | ASSERT_EQ(result.sizes().size(), ref_output.sizes().size()); |
1296 | ASSERT_TRUE(torch::allclose( |
1297 | result, |
1298 | ref_output, |
1299 | 1e-7, |
1300 | 1e-5, |
1301 | /*equal_nan=*/true)); |
1302 | } |
1303 | |
1304 | TEST_F(TransformerTest, TransformerDecoder) { |
1305 | transformer_decoder_test_helper( |
1306 | /*is_cuda=*/false, /*use_callable_activation=*/false); |
1307 | transformer_decoder_test_helper( |
1308 | /*is_cuda=*/false, /*use_callable_activation=*/true); |
1309 | } |
1310 | |
1311 | TEST_F(TransformerTest, TransformerDecoder_CUDA) { |
1312 | transformer_decoder_test_helper( |
1313 | /*is_cuda=*/true, /*use_callable_activation=*/false); |
1314 | transformer_decoder_test_helper( |
1315 | /*is_cuda=*/true, /*use_callable_activation=*/true); |
1316 | } |
1317 | |
// Checks the pretty-printed hierarchy of a TransformerDecoder with two
// identical layers and a final LayerNorm attached via the norm() option.
TEST_F(TransformerTest, PrettyPrintTransformerDecoder) {
  LayerNorm norm = LayerNorm(LayerNormOptions({4}));
  TransformerDecoderOptions options(
      TransformerDecoderOptions(TransformerDecoderLayerOptions(4, 2), 2)
          .norm(AnyModule(norm)));
  ASSERT_EQ(
      c10::str(TransformerDecoder(options)),
      "torch::nn::TransformerDecoderImpl(\n"
      "  (layers): torch::nn::ModuleList(\n"
      "    (0): torch::nn::TransformerDecoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (multihead_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "    (1): torch::nn::TransformerDecoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (multihead_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "  )\n"
      "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      ")");
}
1365 | |
// Deterministic end-to-end test for the full Transformer module. Two models
// with identical constant (cosine) parameters are compared: one where
// Transformer builds its own encoder/decoder stacks, and one that receives
// equivalent stacks via custom_encoder/custom_decoder. Their outputs must
// match each other exactly and match precomputed reference tensors.
void transformer_test_helper(bool is_cuda, bool use_callable_activation) {
  // this is a deterministic test for Transformer
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
  torch::TensorOptions tensor_options =
      torch::TensorOptions().dtype(torch::kFloat32).device(device);

  // transformer created encoder/decoder
  auto options = TransformerOptions()
                     .d_model(4)
                     .nhead(2)
                     .num_encoder_layers(2)
                     .num_decoder_layers(1)
                     .dim_feedforward(16)
                     .dropout(0.0)
                     .activation(torch::kReLU);
  if (use_callable_activation) {
    // replace the enum-based ReLU with a functionally-equivalent callable
    options.activation(
        [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
  }
  Transformer model(options);

  set_parameter_to_constants<Transformer>(model, tensor_options);
  if (tensor_options.device() == torch::kCUDA) {
    model->to(torch::kCUDA);
  }

  // transformer with customized encoder/decoder (same hyperparameters as
  // above, plus a final LayerNorm on each stack)
  LayerNorm enorm(LayerNormOptions({4}));
  TransformerEncoder encoder(
      TransformerEncoderOptions(
          TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
          2)
          .norm(AnyModule(enorm)));

  LayerNorm dnorm(LayerNormOptions({4}));
  TransformerDecoder decoder(
      TransformerDecoderOptions(
          TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
          1)
          .norm(AnyModule(dnorm)));

  Transformer model_cus(TransformerOptions()
                            .d_model(4)
                            .nhead(2)
                            .custom_encoder(AnyModule(encoder))
                            .custom_decoder(AnyModule(decoder)));

  set_parameter_to_constants<Transformer>(model_cus, tensor_options);
  if (tensor_options.device() == torch::kCUDA) {
    model_cus->to(torch::kCUDA);
  }

  // test cases: src is (seq=3, batch=2, d_model=4),
  // tgt is (seq=2, batch=2, d_model=4)
  torch::Tensor src = torch::tensor(
      {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
       {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}},
       {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}},
      tensor_options);

  torch::Tensor tgt = torch::tensor(
      {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
       {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}},
      tensor_options);

  torch::Tensor ref_output = torch::tensor(
      {{{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}},
       {{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}}},
      tensor_options);
  torch::Tensor result = model(src, tgt);
  torch::Tensor result_cus = model_cus(src, tgt);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  // built-in and custom encoder/decoder variants must agree bit-for-bit
  ASSERT_TRUE(result.equal(result_cus));
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // a causal (square subsequent) src mask; the reference values below are
  // identical to the unmasked case above
  torch::Tensor src_mask =
      Transformer::Impl::generate_square_subsequent_mask(src.size(0))
          .to(tensor_options);
  ref_output = torch::tensor(
      {{{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}},
       {{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}}},
      tensor_options);
  result = model(src, tgt, src_mask);
  result_cus = model_cus(src, tgt, src_mask);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(result.equal(result_cus));
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // tgt key-padding mask: shape (batch, tgt_seq) with two positions masked
  torch::Tensor tgt_key_padding_mask =
      torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1;
  tgt_key_padding_mask[0][0] = 1;
  tgt_key_padding_mask[1][1] = 1;
  ref_output = torch::tensor(
      {{{2.696114, 0.347004, -0.044813, -0.548417},
        {2.696091, 0.347015, -0.044770, -0.548522}},
       {{2.696114, 0.347004, -0.044813, -0.548417},
        {2.696091, 0.347015, -0.044770, -0.548522}}},
      tensor_options);
  result = model(
      src,
      tgt,
      src_mask,
      torch::Tensor(),
      torch::Tensor(),
      torch::Tensor(),
      tgt_key_padding_mask);
  result_cus = model_cus(
      src,
      tgt,
      src_mask,
      torch::Tensor(),
      torch::Tensor(),
      torch::Tensor(),
      tgt_key_padding_mask);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(result.equal(result_cus));
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
}
1490 | |
1491 | TEST_F(TransformerTest, Transformer) { |
1492 | transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/false); |
1493 | transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/true); |
1494 | } |
1495 | |
1496 | TEST_F(TransformerTest, Transformer_CUDA) { |
1497 | transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/false); |
1498 | transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/true); |
1499 | } |
1500 | |
1501 | TEST_F(TransformerTest, TransformerArgsCorrectness) { |
1502 | Transformer model(TransformerOptions() |
1503 | .d_model(4) |
1504 | .nhead(2) |
1505 | .num_encoder_layers(2) |
1506 | .num_decoder_layers(1) |
1507 | .dim_feedforward(16) |
1508 | .dropout(0.0) |
1509 | .activation(torch::kReLU)); |
1510 | |
1511 | torch::Tensor src = torch::randn({2, 3, 4}); |
1512 | torch::Tensor tgt = torch::randn({3, 2, 4}); |
1513 | |
1514 | ASSERT_THROWS_WITH( |
1515 | model(src, tgt), "src and tgt should have equal batch size" ); |
1516 | |
1517 | tgt = torch::randn({2, 3, 3}); |
1518 | ASSERT_THROWS_WITH( |
1519 | model(src, tgt), "src and tgt should have same feature size as d_model" ); |
1520 | |
1521 | src = torch::randn({2, 3}); |
1522 | ASSERT_THROWS_WITH(model(src, tgt), "src and tgt should have 3 dimensions" ); |
1523 | } |
1524 | |