1#include <gtest/gtest.h>
2
3#include <torch/torch.h>
4
5#include <test/cpp/api/support.h>
6
7using namespace torch::nn;
8
// Test fixture that re-seeds the RNG before every test so all transformer
// tests below are deterministic and reproducible across runs.
struct TransformerTest : torch::test::SeedingFixture {};
10
11// a generic function to set constants for parameters so we have fixed result
12// for deterministic test
13template <typename Model>
14void set_parameter_to_constants(
15 Model& model,
16 const torch::TensorOptions& tensor_options) {
17 torch::NoGradGuard guard;
18 for (auto& p : model->parameters()) {
19 auto sz = p.view(-1).size(0);
20 p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
21 }
22}
23
24// a generic function to provide consistent encoder/decoder layer for all the
25// transformer tests
26template <typename T_LAYER, typename T_OPTIONS>
27T_LAYER get_a_test_layer(
28 const torch::TensorOptions& tensor_options,
29 bool use_callable_activation) {
30 int64_t d_model = 4;
31 int64_t nhead = 2;
32 int64_t dim_feedforward = 16;
33 double dropout = 0.0;
34
35 // activation is always ReLU here and it can be adjusted later depending on
36 // the usage
37 T_LAYER layer(T_OPTIONS(d_model, nhead)
38 .dim_feedforward(dim_feedforward)
39 .dropout(dropout));
40 if (tensor_options.device() == torch::kCUDA) {
41 layer->to(torch::kCUDA);
42 }
43 if (use_callable_activation) {
44 layer.get()->options.activation(
45 [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
46 }
47
48 // set constant weights of the model
49 set_parameter_to_constants<T_LAYER>(layer, tensor_options);
50
51 return layer;
52}
53
54void transformer_encoder_layer_test_helper(
55 bool is_cuda,
56 bool use_callable_activation) {
57 // this is a deterministic test for TransformerEncoderLayer
58 torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
59 torch::TensorOptions tensor_options =
60 torch::TensorOptions().dtype(torch::kFloat32).device(device);
61
62 TransformerEncoderLayer model =
63 get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
64 tensor_options, use_callable_activation);
65
66 // relu test case 1
67 torch::Tensor encoder_input =
68 torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
69 torch::Tensor result = model(encoder_input).detach();
70 torch::Tensor ref_output = torch::tensor(
71 {{{2.258703, 0.127985, -0.697881, 0.170862}}}, tensor_options);
72 ASSERT_EQ(result.sizes(), ref_output.sizes());
73 ASSERT_TRUE(
74 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
75
76 // all 0 values are NOT masked. This should't mask anything
77 torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1;
78 result = model(
79 encoder_input,
80 /*src_mask=*/torch::Tensor{},
81 /*src_key_padding_mask=*/mask)
82 .detach();
83 ASSERT_EQ(result.sizes(), ref_output.sizes());
84 ASSERT_TRUE(
85 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
86
87 // all 1 values are masked. Since there is only 1 input embedding this will
88 // result in nan.
89 mask = torch::tensor({{1}}, tensor_options) == 1;
90 result = model(
91 encoder_input,
92 /*src_mask=*/torch::Tensor{},
93 /*src_key_padding_mask=*/mask)
94 .detach();
95 ASSERT_TRUE(torch::isnan(result).all().item().to<bool>());
96
97 // relu test case 2
98 encoder_input =
99 torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
100 result = model(encoder_input).detach();
101 ref_output = torch::tensor(
102 {{{2.272644, 0.119035, -0.691669, 0.153486}},
103 {{2.272644, 0.119035, -0.691669, 0.153486}}},
104 tensor_options);
105 ASSERT_EQ(result.sizes(), ref_output.sizes());
106 ASSERT_TRUE(
107 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
108
109 // all 0 values are NOT masked
110 mask = torch::tensor({{0, 0}}, tensor_options) == 1;
111 result = model(
112 encoder_input,
113 /*src_mask=*/torch::Tensor{},
114 /*src_key_padding_mask=*/mask)
115 .detach();
116 ASSERT_EQ(result.sizes(), ref_output.sizes());
117 ASSERT_TRUE(
118 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
119
120 // mask with 1 and 0
121 mask = torch::tensor({{1, 0}}, tensor_options) == 1;
122 result = model(
123 encoder_input,
124 /*src_mask=*/torch::Tensor{},
125 /*src_key_padding_mask=*/mask)
126 .detach();
127 ref_output = torch::tensor(
128 {{{2.301516, 0.092249, -0.679101, 0.103088}},
129 {{2.301516, 0.092249, -0.679101, 0.103088}}},
130 tensor_options);
131 ASSERT_EQ(result.sizes(), ref_output.sizes());
132 ASSERT_TRUE(
133 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
134
135 // relu test case 3
136 encoder_input = torch::tensor(
137 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
138 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
139 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
140 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
141 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
142 tensor_options);
143 result = model(encoder_input).detach();
144 ref_output = torch::tensor(
145 {{{2.428589, 0.020835, -0.602055, -0.085249},
146 {2.427987, 0.021213, -0.602496, -0.084103}},
147 {{2.424689, 0.019155, -0.604793, -0.085672},
148 {2.413863, 0.022211, -0.612486, -0.072490}},
149 {{2.433774, 0.021598, -0.598343, -0.087548},
150 {2.425104, 0.019748, -0.604515, -0.084839}},
151 {{2.436185, 0.022682, -0.596625, -0.087261},
152 {2.433556, 0.021891, -0.598509, -0.086832}},
153 {{2.416246, 0.017512, -0.610712, -0.082961},
154 {2.422901, 0.024187, -0.606178, -0.074929}}},
155 tensor_options);
156 ASSERT_EQ(result.sizes(), ref_output.sizes());
157 ASSERT_TRUE(
158 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
159
160 // all 0 values are NOT masked
161 mask = torch::zeros({2, 5}, tensor_options) == 1;
162 result = model(
163 encoder_input,
164 /*src_mask=*/torch::Tensor{},
165 /*src_key_padding_mask=*/mask)
166 .detach();
167 ASSERT_EQ(result.sizes(), ref_output.sizes());
168 ASSERT_TRUE(
169 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
170
171 // mask with 0s and 1s
172 mask[0][1] = 1;
173 mask[1][3] = 1;
174 mask[1][4] = 1;
175 result = model(
176 encoder_input,
177 /*src_mask=*/torch::Tensor{},
178 /*src_key_padding_mask=*/mask)
179 .detach();
180 ref_output = torch::tensor(
181 {{{2.429026, 0.020793, -0.601741, -0.085642},
182 {2.428811, 0.021445, -0.601912, -0.084252}},
183 {{2.425009, 0.019155, -0.604566, -0.085899},
184 {2.415408, 0.02249, -0.611415, -0.073}},
185 {{2.434199, 0.021682, -0.598039, -0.087699},
186 {2.42598, 0.019941, -0.603896, -0.085091}},
187 {{2.436457, 0.022736, -0.59643, -0.08736},
188 {2.434021, 0.022093, -0.598179, -0.08679}},
189 {{2.416531, 0.017498, -0.610513, -0.083181},
190 {2.4242, 0.024653, -0.605266, -0.074959}}},
191 tensor_options);
192 ASSERT_EQ(result.sizes(), ref_output.sizes());
193 ASSERT_TRUE(
194 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
195
196 // gelu test case 1
197 model.get()->options.activation(torch::kGELU);
198 encoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
199 result = model(encoder_input).detach();
200 ref_output = torch::tensor(
201 {{{2.249815, 0.131006, -0.702199, 0.177868}}}, tensor_options);
202 ASSERT_EQ(result.sizes(), ref_output.sizes());
203 ASSERT_TRUE(
204 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
205
206 // gelu test case 2
207 encoder_input = torch::tensor(
208 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
209 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
210 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
211 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
212 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
213 tensor_options);
214 result = model(encoder_input);
215 ref_output = torch::tensor(
216 {{{2.42163188, 0.03227153, -0.60714219, -0.05908082},
217 {2.42151276, 0.03302179, -0.60722523, -0.05762651}},
218 {{2.41926761, 0.02974034, -0.60879519, -0.0621269},
219 {2.41626395, 0.03539356, -0.61087842, -0.04978623}},
220 {{2.42382808, 0.03218872, -0.6055963, -0.06073591},
221 {2.41983477, 0.03085259, -0.60840145, -0.06046414}},
222 {{2.42500749, 0.03328855, -0.60476388, -0.0595334},
223 {2.4237977, 0.03290575, -0.60561789, -0.05940082}},
224 {{2.41383916, 0.02686345, -0.61256377, -0.06380707},
225 {2.42000277, 0.03800944, -0.60824798, -0.04754947}}},
226 tensor_options);
227 ASSERT_EQ(result.sizes(), ref_output.sizes());
228 ASSERT_TRUE(
229 torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
230}
231
// CPU run of the encoder-layer helper for both activation variants
// (enum-based and callable).
TEST_F(TransformerTest, TransformerEncoderLayer) {
  transformer_encoder_layer_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/false);
  transformer_encoder_layer_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/true);
}
238
// CUDA run of the encoder-layer helper (the _CUDA suffix lets the test
// runner gate this on GPU availability).
TEST_F(TransformerTest, TransformerEncoderLayer_CUDA) {
  transformer_encoder_layer_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/false);
  transformer_encoder_layer_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/true);
}
245
246void transformer_decoder_layer_test_helper(
247 bool is_cuda,
248 bool use_callable_activation) {
249 torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
250 torch::TensorOptions tensor_options =
251 torch::TensorOptions().dtype(torch::kFloat32).device(device);
252
253 TransformerDecoderLayer model =
254 get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
255 tensor_options, use_callable_activation);
256
257 // deterministic input
258 torch::Tensor decoder_input =
259 torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
260 torch::Tensor memory_input =
261 torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
262 torch::Tensor result = model(decoder_input, memory_input).detach();
263 torch::Tensor ref_output = torch::tensor(
264 {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
265 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
266 ASSERT_TRUE(torch::allclose(
267 result,
268 ref_output,
269 1e-7,
270 1e-5,
271 /*equal_nan=*/true));
272
273 // deterministic input
274 decoder_input =
275 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
276 memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
277 result = model(decoder_input, memory_input).detach();
278 ref_output = torch::tensor(
279 {{{2.422245, 0.051716, -0.606338, -0.024756}},
280 {{2.422245, 0.051716, -0.606338, -0.024756}}},
281 tensor_options);
282 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
283 ASSERT_TRUE(torch::allclose(
284 result,
285 ref_output,
286 1e-7,
287 1e-5,
288 /*equal_nan=*/true));
289
290 // deterministic input
291 decoder_input =
292 torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
293 memory_input =
294 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
295 result = model(decoder_input, memory_input).detach();
296 ref_output = torch::tensor(
297 {{{2.343536, 0.085561, -0.654954, 0.074991}},
298 {{2.343536, 0.085561, -0.654954, 0.074991}}},
299 tensor_options);
300 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
301 ASSERT_TRUE(torch::allclose(
302 result,
303 ref_output,
304 1e-7,
305 1e-5,
306 /*equal_nan=*/true));
307
308 // deterministic input
309 decoder_input = torch::tensor(
310 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
311 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
312 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
313 tensor_options);
314 memory_input = torch::tensor(
315 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
316 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
317 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
318 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
319 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
320 tensor_options);
321 result = model(decoder_input, memory_input).detach();
322 ref_output = torch::tensor(
323 {{{2.430065, 0.027862, -0.601136, -0.073096},
324 {2.431935, 0.028907, -0.599809, -0.072488}},
325 {{2.428457, 0.027053, -0.602275, -0.073462},
326 {2.431970, 0.029387, -0.599789, -0.071621}},
327 {{2.431934, 0.028196, -0.599802, -0.073809},
328 {2.432306, 0.028858, -0.599542, -0.072846}}},
329 tensor_options);
330 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
331 ASSERT_TRUE(torch::allclose(
332 result,
333 ref_output,
334 1e-7,
335 1e-5,
336 /*equal_nan=*/true));
337
338 // key_padding_mask
339 torch::Tensor t_mask = {};
340 torch::Tensor m_mask = {};
341 torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
342 result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
343 .detach();
344 ref_output = torch::tensor(
345 {{{2.430065, 0.027862, -0.601136, -0.073096},
346 {2.431935, 0.028907, -0.599809, -0.072488}},
347 {{2.428457, 0.027053, -0.602275, -0.073462},
348 {2.431970, 0.029387, -0.599789, -0.071621}},
349 {{2.431934, 0.028196, -0.599802, -0.073809},
350 {2.432306, 0.028858, -0.599542, -0.072846}}},
351 tensor_options);
352 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
353 ASSERT_TRUE(torch::allclose(
354 result,
355 ref_output,
356 1e-7,
357 1e-5,
358 /*equal_nan=*/true));
359
360 // key_padding_mask
361 key_padding_mask[0][2] = 1;
362 key_padding_mask[1][1] = 1;
363 key_padding_mask[1][2] = 1;
364 result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
365 .detach();
366 ref_output = torch::tensor(
367 {{{2.430025, 0.027643, -0.601164, -0.073476},
368 {2.4323, 0.029375, -0.599553, -0.071881}},
369 {{2.428523, 0.026838, -0.602226, -0.07391},
370 {2.432634, 0.029842, -0.599318, -0.071253}},
371 {{2.432278, 0.028152, -0.599555, -0.074139},
372 {2.432659, 0.029244, -0.599294, -0.072382}}},
373 tensor_options);
374 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
375 ASSERT_TRUE(torch::allclose(
376 result,
377 ref_output,
378 1e-7,
379 1e-5,
380 /*equal_nan=*/true));
381
382 // memory_key_padding_mask
383 torch::Tensor t_key_padding_mask = {};
384 key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
385 result = model(
386 decoder_input,
387 memory_input,
388 t_mask,
389 m_mask,
390 t_key_padding_mask,
391 key_padding_mask)
392 .detach();
393 ref_output = torch::tensor(
394 {{{2.430065, 0.027862, -0.601136, -0.073096},
395 {2.431935, 0.028907, -0.599809, -0.072488}},
396 {{2.428457, 0.027053, -0.602275, -0.073462},
397 {2.431970, 0.029387, -0.599789, -0.071621}},
398 {{2.431934, 0.028196, -0.599802, -0.073809},
399 {2.432306, 0.028858, -0.599542, -0.072846}}},
400 tensor_options);
401 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
402 ASSERT_TRUE(torch::allclose(
403 result,
404 ref_output,
405 1e-7,
406 1e-5,
407 /*equal_nan=*/true));
408
409 // memory_key_padding_mask
410 key_padding_mask[0][4] = 1;
411 key_padding_mask[1][3] = 1;
412 key_padding_mask[1][4] = 1;
413 result = model(
414 decoder_input,
415 memory_input,
416 t_mask,
417 m_mask,
418 t_key_padding_mask,
419 key_padding_mask)
420 .detach();
421 ref_output = torch::tensor(
422 {{{2.429757, 0.027358, -0.601351, -0.073816},
423 {2.432692, 0.028583, -0.599263, -0.073634}},
424 {{2.428247, 0.02662, -0.602419, -0.074123},
425 {2.432657, 0.029055, -0.599293, -0.072732}},
426 {{2.431515, 0.027687, -0.600096, -0.074459},
427 {2.433075, 0.028543, -0.598987, -0.073985}}},
428 tensor_options);
429 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
430 ASSERT_TRUE(torch::allclose(
431 result,
432 ref_output,
433 1e-7,
434 1e-5,
435 /*equal_nan=*/true));
436}
437
// CPU run of the decoder-layer helper for both activation variants
// (enum-based and callable).
TEST_F(TransformerTest, TransformerDecoderLayer) {
  transformer_decoder_layer_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/false);
  transformer_decoder_layer_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/true);
}
444
// CUDA run of the decoder-layer helper (the _CUDA suffix lets the test
// runner gate this on GPU availability).
TEST_F(TransformerTest, TransformerDecoderLayer_CUDA) {
  transformer_decoder_layer_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/false);
  transformer_decoder_layer_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/true);
}
451
452void transformer_decoder_layer_test_helper_gelu(
453 bool is_cuda,
454 bool use_callable_activation) {
455 torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
456 torch::TensorOptions tensor_options =
457 torch::TensorOptions().dtype(torch::kFloat32).device(device);
458
459 TransformerDecoderLayer model =
460 get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
461 tensor_options, use_callable_activation);
462 if (use_callable_activation) {
463 model.get()->options.activation(
464 [&](const torch::Tensor& t) { return torch::nn::functional::gelu(t); });
465 } else {
466 model.get()->options.activation(torch::kGELU);
467 }
468
469 // deterministic input
470 torch::Tensor decoder_input =
471 torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
472 torch::Tensor memory_input =
473 torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
474 torch::Tensor result = model(decoder_input, memory_input).detach();
475 torch::Tensor ref_output = torch::tensor(
476 {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
477 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
478 ASSERT_TRUE(torch::allclose(
479 result,
480 ref_output,
481 1e-7,
482 1e-5,
483 /*equal_nan=*/true));
484
485 // deterministic input
486 decoder_input =
487 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
488 memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
489 result = model(decoder_input, memory_input).detach();
490 ref_output = torch::tensor(
491 {{{2.415448, 0.054389, -0.610932, -0.0156613}},
492 {{2.415448, 0.054389, -0.610932, -0.0156613}}},
493 tensor_options);
494 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
495 ASSERT_TRUE(torch::allclose(
496 result,
497 ref_output,
498 1e-7,
499 1e-5,
500 /*equal_nan=*/true));
501
502 // deterministic input
503 decoder_input =
504 torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
505 memory_input =
506 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
507 result = model(decoder_input, memory_input).detach();
508 ref_output = torch::tensor(
509 {{{2.338531, 0.087709, -0.65776, 0.080646}},
510 {{2.338531, 0.087709, -0.65776, 0.080646}}},
511 tensor_options);
512 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
513 ASSERT_TRUE(torch::allclose(
514 result,
515 ref_output,
516 1e-7,
517 1e-5,
518 /*equal_nan=*/true));
519
520 // deterministic input
521 decoder_input = torch::tensor(
522 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
523 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
524 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
525 tensor_options);
526 memory_input = torch::tensor(
527 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
528 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
529 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
530 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
531 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
532 tensor_options);
533 result = model(decoder_input, memory_input).detach();
534 ref_output = torch::tensor(
535 {{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
536 {2.42210631, 0.03546578, -0.60679895, -0.05357488}},
537 {{2.41907674, 0.0336104, -0.60892977, -0.05490462},
538 {2.42216881, 0.03586554, -0.6067524, -0.05289126}},
539 {{2.42205716, 0.03488046, -0.60683681, -0.05460596},
540 {2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
541 tensor_options);
542 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
543 ASSERT_TRUE(torch::allclose(
544 result,
545 ref_output,
546 1e-7,
547 1e-5,
548 /*equal_nan=*/true));
549}
550
// CPU run of the GELU decoder-layer helper for both activation variants.
TEST_F(TransformerTest, TransformerDecoderLayer_gelu) {
  transformer_decoder_layer_test_helper_gelu(
      /*is_cuda=*/false, /*use_callable_activation=*/false);
  transformer_decoder_layer_test_helper_gelu(
      /*is_cuda=*/false, /*use_callable_activation=*/true);
}
557
// CUDA run of the GELU decoder-layer helper (the _CUDA suffix lets the
// test runner gate this on GPU availability).
TEST_F(TransformerTest, TransformerDecoderLayer_gelu_CUDA) {
  transformer_decoder_layer_test_helper_gelu(
      /*is_cuda=*/true, /*use_callable_activation=*/false);
  transformer_decoder_layer_test_helper_gelu(
      /*is_cuda=*/true, /*use_callable_activation=*/true);
}
564
void transformer_encoder_test_helper(
    bool is_cuda,
    bool use_callable_activation) {
  // Deterministic test for TransformerEncoder (a stack of encoder layers):
  // the shared layer has constant weights, so outputs of 1-, 2-, and 6-layer
  // stacks — with and without a final LayerNorm — can be compared against
  // precomputed reference values.
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
  torch::TensorOptions tensor_options =
      torch::TensorOptions().dtype(torch::kFloat32).device(device);

  TransformerEncoderLayer encoder_layer =
      get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
          tensor_options, use_callable_activation);

  // test case 1: a single-layer stack
  TransformerEncoder model(TransformerEncoderOptions(encoder_layer, 1));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }

  torch::Tensor encoder_input = torch::tensor(
      {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
       {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
       {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
       {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
       {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
      tensor_options);
  torch::Tensor result = model(encoder_input).detach();
  torch::Tensor ref_output = torch::tensor(
      {{{2.428589, 0.020835, -0.602055, -0.085249},
        {2.427987, 0.021213, -0.602496, -0.084103}},
       {{2.424689, 0.019155, -0.604793, -0.085672},
        {2.413863, 0.022211, -0.612486, -0.072490}},
       {{2.433774, 0.021598, -0.598343, -0.087548},
        {2.425104, 0.019748, -0.604515, -0.084839}},
       {{2.436185, 0.022682, -0.596625, -0.087261},
        {2.433556, 0.021891, -0.598509, -0.086832}},
       {{2.416246, 0.017512, -0.610712, -0.082961},
        {2.422901, 0.024187, -0.606178, -0.074929}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // all 0 values are NOT masked
  torch::Tensor mask = torch::zeros({2, 5}, tensor_options) == 1;
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // mask with 0s and 1s
  mask[0][1] = 1;
  mask[1][3] = 1;
  mask[1][4] = 1;
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{2.429026, 0.020793, -0.601741, -0.085642},
        {2.428811, 0.021445, -0.601912, -0.084252}},
       {{2.425009, 0.019155, -0.604566, -0.085899},
        {2.415408, 0.02249, -0.611415, -0.073}},
       {{2.434199, 0.021682, -0.598039, -0.087699},
        {2.42598, 0.019941, -0.603896, -0.085091}},
       {{2.436457, 0.022736, -0.59643, -0.08736},
        {2.434021, 0.022093, -0.598179, -0.08679}},
       {{2.416531, 0.017498, -0.610513, -0.083181},
        {2.4242, 0.024653, -0.605266, -0.074959}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // test case 2, multiple layers no norm
  model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 2));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{2.419051, 0.017446, -0.608738, -0.085003},
        {2.419102, 0.017452, -0.608703, -0.085026}},
       {{2.419043, 0.017445, -0.608744, -0.084999},
        {2.419052, 0.017446, -0.608738, -0.085004}},
       {{2.419067, 0.017448, -0.608727, -0.085010},
        {2.419098, 0.017452, -0.608706, -0.085024}},
       {{2.419072, 0.017449, -0.608724, -0.085012},
        {2.419119, 0.017455, -0.608691, -0.085034}},
       {{2.419019, 0.017442, -0.608761, -0.084989},
        {2.419075, 0.017449, -0.608722, -0.085014}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // with 6 layers the constant-weight stack converges to a fixed point, so
  // all positions produce (nearly) identical outputs
  model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 6));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}},
       {{2.419101, 0.017453, -0.608703, -0.085025},
        {2.419101, 0.017453, -0.608704, -0.085025}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // test case 3, multiple layers with a final LayerNorm applied to the
  // stack's output
  LayerNorm norm(LayerNormOptions({encoder_layer.get()->options.d_model()}));
  model = TransformerEncoder(
      TransformerEncoderOptions(encoder_layer, 2).norm(AnyModule(norm)));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{1.695949, -0.357635, -0.893077, -0.445238},
        {1.695955, -0.357639, -0.893050, -0.445266}},
       {{1.695948, -0.357634, -0.893082, -0.445233},
        {1.695950, -0.357635, -0.893077, -0.445238}},
       {{1.695951, -0.357636, -0.893069, -0.445246},
        {1.695955, -0.357639, -0.893052, -0.445264}},
       {{1.695952, -0.357636, -0.893066, -0.445249},
        {1.695957, -0.357641, -0.893041, -0.445276}},
       {{1.695946, -0.357632, -0.893095, -0.445220},
        {1.695952, -0.357637, -0.893065, -0.445251}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  model = TransformerEncoder(
      TransformerEncoderOptions(encoder_layer, 6).norm(AnyModule(norm)));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
  result = model(
               encoder_input,
               /*src_mask=*/torch::Tensor{},
               /*src_key_padding_mask=*/mask)
               .detach();
  ref_output = torch::tensor(
      {{{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}},
       {{1.695955, -0.357639, -0.893051, -0.445265},
        {1.695955, -0.357639, -0.893051, -0.445265}}},
      tensor_options);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
}
747
// CPU run of the encoder-stack helper for both activation variants.
TEST_F(TransformerTest, TransformerEncoder) {
  transformer_encoder_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/false);
  transformer_encoder_test_helper(
      /*is_cuda=*/false, /*use_callable_activation=*/true);
}
754
// CUDA run of the encoder-stack helper (the _CUDA suffix lets the test
// runner gate this on GPU availability).
TEST_F(TransformerTest, TransformerEncoder_CUDA) {
  transformer_encoder_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/false);
  transformer_encoder_test_helper(
      /*is_cuda=*/true, /*use_callable_activation=*/true);
}
761
// Pins the exact pretty-printed module hierarchy of a default
// TransformerEncoderLayer(d_model=4, nhead=2); the expected string must
// stay byte-identical to the printer's output.
TEST_F(TransformerTest, PrettyPrintTransformerEncoderLayer) {
  ASSERT_EQ(
      c10::str(TransformerEncoderLayer(4, 2)),
      "torch::nn::TransformerEncoderLayerImpl(\n"
      "  (self_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      ")");
}
778
// Pins the exact pretty-printed hierarchy of a 2-layer TransformerEncoder
// with a final LayerNorm; the expected string must stay byte-identical to
// the printer's output.
TEST_F(TransformerTest, PrettyPrintTransformerEncoder) {
  LayerNorm norm = LayerNorm(LayerNormOptions({4}));
  TransformerEncoderOptions options(
      TransformerEncoderOptions(TransformerEncoderLayerOptions(4, 2), 2)
          .norm(AnyModule(norm)));
  ASSERT_EQ(
      c10::str(TransformerEncoder(options)),
      "torch::nn::TransformerEncoderImpl(\n"
      "  (layers): torch::nn::ModuleList(\n"
      "    (0): torch::nn::TransformerEncoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "    (1): torch::nn::TransformerEncoderLayerImpl(\n"
      "      (self_attn): torch::nn::MultiheadAttention(\n"
      "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "      )\n"
      "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "    )\n"
      "  )\n"
      "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      ")");
}
816
// Pins the exact pretty-printed module hierarchy of a default
// TransformerDecoderLayer(d_model=4, nhead=2); the expected string must
// stay byte-identical to the printer's output.
TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) {
  ASSERT_EQ(
      c10::str(TransformerDecoderLayer(4, 2)),
      "torch::nn::TransformerDecoderLayerImpl(\n"
      "  (self_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (multihead_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
      "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      ")");
}
838
839void transformer_decoder_test_helper(
840 bool is_cuda,
841 bool use_callable_activation) {
842 // this is a deterministic test for TransformerDecoder
843 torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
844 torch::TensorOptions tensor_options =
845 torch::TensorOptions().dtype(torch::kFloat32).device(device);
846
847 TransformerDecoderLayer decoder_layer =
848 get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
849 tensor_options, use_callable_activation);
850
851 TransformerDecoder model(TransformerDecoderOptions(decoder_layer, 1));
852 if (is_cuda) {
853 model->to(torch::kCUDA);
854 }
855
856 torch::Tensor decoder_input =
857 torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
858 torch::Tensor memory_input =
859 torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
860 torch::Tensor result = model(decoder_input, memory_input).detach();
861 torch::Tensor ref_output = torch::tensor(
862 {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
863 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
864 ASSERT_TRUE(torch::allclose(
865 result,
866 ref_output,
867 1e-7,
868 1e-5,
869 /*equal_nan=*/true));
870
871 // deterministic input
872 decoder_input =
873 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
874 memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
875 result = model(decoder_input, memory_input).detach();
876 ref_output = torch::tensor(
877 {{{2.422245, 0.051716, -0.606338, -0.024756}},
878 {{2.422245, 0.051716, -0.606338, -0.024756}}},
879 tensor_options);
880 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
881 ASSERT_TRUE(torch::allclose(
882 result,
883 ref_output,
884 1e-7,
885 1e-5,
886 /*equal_nan=*/true));
887
888 // deterministic input
889 decoder_input =
890 torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
891 memory_input =
892 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
893 result = model(decoder_input, memory_input).detach();
894 ref_output = torch::tensor(
895 {{{2.343536, 0.085561, -0.654954, 0.074991}},
896 {{2.343536, 0.085561, -0.654954, 0.074991}}},
897 tensor_options);
898 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
899 ASSERT_TRUE(torch::allclose(
900 result,
901 ref_output,
902 1e-7,
903 1e-5,
904 /*equal_nan=*/true));
905
906 // deterministic input
907 decoder_input = torch::tensor(
908 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
909 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
910 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
911 tensor_options);
912 memory_input = torch::tensor(
913 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
914 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
915 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
916 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
917 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
918 tensor_options);
919 result = model(decoder_input, memory_input).detach();
920 ref_output = torch::tensor(
921 {{{2.430065, 0.027862, -0.601136, -0.073096},
922 {2.431935, 0.028907, -0.599809, -0.072488}},
923 {{2.428457, 0.027053, -0.602275, -0.073462},
924 {2.431970, 0.029387, -0.599789, -0.071621}},
925 {{2.431934, 0.028196, -0.599802, -0.073809},
926 {2.432306, 0.028858, -0.599542, -0.072846}}},
927 tensor_options);
928 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
929 ASSERT_TRUE(torch::allclose(
930 result,
931 ref_output,
932 1e-7,
933 1e-5,
934 /*equal_nan=*/true));
935
936 // key_padding_mask
937 torch::Tensor t_mask = {};
938 torch::Tensor m_mask = {};
939 torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
940 result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
941 .detach();
942 ref_output = torch::tensor(
943 {{{2.430065, 0.027862, -0.601136, -0.073096},
944 {2.431935, 0.028907, -0.599809, -0.072488}},
945 {{2.428457, 0.027053, -0.602275, -0.073462},
946 {2.431970, 0.029387, -0.599789, -0.071621}},
947 {{2.431934, 0.028196, -0.599802, -0.073809},
948 {2.432306, 0.028858, -0.599542, -0.072846}}},
949 tensor_options);
950 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
951 ASSERT_TRUE(torch::allclose(
952 result,
953 ref_output,
954 1e-7,
955 1e-5,
956 /*equal_nan=*/true));
957
958 // key_padding_mask
959 key_padding_mask[0][2] = 1;
960 key_padding_mask[1][1] = 1;
961 key_padding_mask[1][2] = 1;
962 result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
963 .detach();
964 ref_output = torch::tensor(
965 {{{2.430025, 0.027643, -0.601164, -0.073476},
966 {2.4323, 0.029375, -0.599553, -0.071881}},
967 {{2.428523, 0.026838, -0.602226, -0.07391},
968 {2.432634, 0.029842, -0.599318, -0.071253}},
969 {{2.432278, 0.028152, -0.599555, -0.074139},
970 {2.432659, 0.029244, -0.599294, -0.072382}}},
971 tensor_options);
972 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
973 ASSERT_TRUE(torch::allclose(
974 result,
975 ref_output,
976 1e-7,
977 1e-5,
978 /*equal_nan=*/true));
979
980 // memory_key_padding_mask
981 torch::Tensor t_key_padding_mask = {};
982 key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
983 result = model(
984 decoder_input,
985 memory_input,
986 t_mask,
987 m_mask,
988 t_key_padding_mask,
989 key_padding_mask)
990 .detach();
991 ref_output = torch::tensor(
992 {{{2.430065, 0.027862, -0.601136, -0.073096},
993 {2.431935, 0.028907, -0.599809, -0.072488}},
994 {{2.428457, 0.027053, -0.602275, -0.073462},
995 {2.431970, 0.029387, -0.599789, -0.071621}},
996 {{2.431934, 0.028196, -0.599802, -0.073809},
997 {2.432306, 0.028858, -0.599542, -0.072846}}},
998 tensor_options);
999 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1000 ASSERT_TRUE(torch::allclose(
1001 result,
1002 ref_output,
1003 1e-7,
1004 1e-5,
1005 /*equal_nan=*/true));
1006
1007 // memory_key_padding_mask
1008 key_padding_mask[0][4] = 1;
1009 key_padding_mask[1][3] = 1;
1010 key_padding_mask[1][4] = 1;
1011 result = model(
1012 decoder_input,
1013 memory_input,
1014 t_mask,
1015 m_mask,
1016 t_key_padding_mask,
1017 key_padding_mask)
1018 .detach();
1019 ref_output = torch::tensor(
1020 {{{2.429757, 0.027358, -0.601351, -0.073816},
1021 {2.432692, 0.028583, -0.599263, -0.073634}},
1022 {{2.428247, 0.02662, -0.602419, -0.074123},
1023 {2.432657, 0.029055, -0.599293, -0.072732}},
1024 {{2.431515, 0.027687, -0.600096, -0.074459},
1025 {2.433075, 0.028543, -0.598987, -0.073985}}},
1026 tensor_options);
1027 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1028 ASSERT_TRUE(torch::allclose(
1029 result,
1030 ref_output,
1031 1e-7,
1032 1e-5,
1033 /*equal_nan=*/true));
1034
1035 // multiple layers no norm
1036 model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 2));
1037 if (is_cuda) {
1038 model->to(torch::kCUDA);
1039 }
1040
1041 decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1042 memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1043 result = model(decoder_input, memory_input).detach();
1044 ref_output = torch::tensor(
1045 {{{2.31316, 0.0950293, -0.671995, 0.102802}}}, tensor_options);
1046 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1047 ASSERT_TRUE(torch::allclose(
1048 result,
1049 ref_output,
1050 1e-7,
1051 1e-5,
1052 /*equal_nan=*/true));
1053
1054 // multiple layers no norm
1055 model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1056 if (is_cuda) {
1057 model->to(torch::kCUDA);
1058 }
1059 // deterministic input
1060 decoder_input = torch::tensor(
1061 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1062 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1063 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1064 tensor_options);
1065 memory_input = torch::tensor(
1066 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1067 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1068 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1069 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1070 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1071 tensor_options);
1072 result = model(decoder_input, memory_input).detach();
1073 ref_output = torch::tensor(
1074 {{{2.42794, 0.026164, -0.60263, -0.0747591},
1075 {2.43113, 0.0279516, -0.600376, -0.0736896}},
1076 {{2.42794, 0.026164, -0.60263, -0.0747591},
1077 {2.43113, 0.0279516, -0.600376, -0.0736896}},
1078 {{2.42794, 0.026164, -0.60263, -0.0747591},
1079 {2.43113, 0.0279516, -0.600376, -0.0736896}}},
1080 tensor_options);
1081 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1082 ASSERT_TRUE(torch::allclose(
1083 result,
1084 ref_output,
1085 1e-7,
1086 1e-5,
1087 /*equal_nan=*/true));
1088
1089 // multiple layers with norm
1090 LayerNorm norm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1091 model = TransformerDecoder(
1092 TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm)));
1093 if (is_cuda) {
1094 model->to(torch::kCUDA);
1095 }
1096
1097 decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1098 memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1099 result = model(decoder_input, memory_input).detach();
1100 ref_output = torch::tensor(
1101 {{{1.66166, -0.326986, -1.01466, -0.320017}}}, tensor_options);
1102 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1103 ASSERT_TRUE(torch::allclose(
1104 result,
1105 ref_output,
1106 1e-7,
1107 1e-5,
1108 /*equal_nan=*/true));
1109
1110 // multiple layers with norm
1111 model = TransformerDecoder(
1112 TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1113 if (is_cuda) {
1114 model->to(torch::kCUDA);
1115 }
1116 // deterministic input
1117 decoder_input = torch::tensor(
1118 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1119 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1120 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1121 tensor_options);
1122 memory_input = torch::tensor(
1123 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1124 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1125 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1126 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1127 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1128 tensor_options);
1129 result = model(decoder_input, memory_input).detach();
1130 ref_output = torch::tensor(
1131 {{{1.69559, -0.357291, -0.894741, -0.443553},
1132 {1.69571, -0.357363, -0.894154, -0.444196}},
1133 {{1.69559, -0.357291, -0.894741, -0.443553},
1134 {1.69571, -0.357363, -0.894154, -0.444196}},
1135 {{1.69559, -0.357291, -0.894741, -0.443553},
1136 {1.69571, -0.357363, -0.894154, -0.444196}}},
1137 tensor_options);
1138 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1139 ASSERT_TRUE(torch::allclose(
1140 result,
1141 ref_output,
1142 1e-7,
1143 1e-5,
1144 /*equal_nan=*/true));
1145
1146 // gelu activation test cases
1147 decoder_layer.get()->options.activation(torch::kGELU);
1148 model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 1));
1149 if (is_cuda) {
1150 model->to(torch::kCUDA);
1151 }
1152
1153 // deterministic input
1154 decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1155 memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1156 result = model(decoder_input, memory_input).detach();
1157 ref_output = torch::tensor(
1158 {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
1159 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1160 ASSERT_TRUE(torch::allclose(
1161 result,
1162 ref_output,
1163 1e-7,
1164 1e-5,
1165 /*equal_nan=*/true));
1166
1167 // deterministic input
1168 decoder_input =
1169 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1170 memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
1171 result = model(decoder_input, memory_input).detach();
1172 ref_output = torch::tensor(
1173 {{{2.415448, 0.054389, -0.610932, -0.0156613}},
1174 {{2.415448, 0.054389, -0.610932, -0.0156613}}},
1175 tensor_options);
1176 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1177 ASSERT_TRUE(torch::allclose(
1178 result,
1179 ref_output,
1180 1e-7,
1181 1e-5,
1182 /*equal_nan=*/true));
1183
1184 // deterministic input
1185 decoder_input =
1186 torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
1187 memory_input =
1188 torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1189 result = model(decoder_input, memory_input).detach();
1190 ref_output = torch::tensor(
1191 {{{2.338531, 0.087709, -0.65776, 0.080646}},
1192 {{2.338531, 0.087709, -0.65776, 0.080646}}},
1193 tensor_options);
1194 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1195 ASSERT_TRUE(torch::allclose(
1196 result,
1197 ref_output,
1198 1e-7,
1199 1e-5,
1200 /*equal_nan=*/true));
1201
1202 // deterministic input
1203 decoder_input = torch::tensor(
1204 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1205 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1206 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1207 tensor_options);
1208 memory_input = torch::tensor(
1209 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1210 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1211 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1212 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1213 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1214 tensor_options);
1215 result = model(decoder_input, memory_input).detach();
1216 ref_output = torch::tensor(
1217 {{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
1218 {2.42210631, 0.03546578, -0.60679895, -0.05357488}},
1219 {{2.41907674, 0.0336104, -0.60892977, -0.05490462},
1220 {2.42216881, 0.03586554, -0.6067524, -0.05289126}},
1221 {{2.42205716, 0.03488046, -0.60683681, -0.05460596},
1222 {2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
1223 tensor_options);
1224 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1225 ASSERT_TRUE(torch::allclose(
1226 result,
1227 ref_output,
1228 1e-7,
1229 1e-5,
1230 /*equal_nan=*/true));
1231
1232 // Multiple layers no norm
1233 model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1234 if (is_cuda) {
1235 model->to(torch::kCUDA);
1236 }
1237 decoder_input = torch::tensor(
1238 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1239 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1240 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1241 tensor_options);
1242 memory_input = torch::tensor(
1243 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1244 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1245 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1246 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1247 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1248 tensor_options);
1249 result = model(decoder_input, memory_input).detach();
1250 ref_output = torch::tensor(
1251 {{{2.41859, 0.0328114, -0.609269, -0.0560386},
1252 {2.42138, 0.034598, -0.607316, -0.0546574}},
1253 {{2.41859, 0.0328114, -0.609269, -0.0560386},
1254 {2.42138, 0.034598, -0.607316, -0.0546574}},
1255 {{2.41859, 0.0328114, -0.609269, -0.0560386},
1256 {2.42138, 0.034598, -0.607316, -0.0546574}}},
1257 tensor_options);
1258 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1259 ASSERT_TRUE(torch::allclose(
1260 result,
1261 ref_output,
1262 1e-7,
1263 1e-5,
1264 /*equal_nan=*/true));
1265
1266 // Multiple layers with norm
1267 norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1268 model = TransformerDecoder(
1269 TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1270 if (is_cuda) {
1271 model->to(torch::kCUDA);
1272 }
1273
1274 decoder_input = torch::tensor(
1275 {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1276 {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1277 {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1278 tensor_options);
1279 memory_input = torch::tensor(
1280 {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1281 {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1282 {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1283 {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1284 {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1285 tensor_options);
1286 result = model(decoder_input, memory_input).detach();
1287 ref_output = torch::tensor(
1288 {{{1.69298, -0.355163, -0.906375, -0.431439},
1289 {1.69305, -0.355195, -0.906062, -0.431791}},
1290 {{1.69298, -0.355163, -0.906375, -0.431439},
1291 {1.69305, -0.355195, -0.906062, -0.431791}},
1292 {{1.69298, -0.355163, -0.906375, -0.431439},
1293 {1.69305, -0.355195, -0.906062, -0.431791}}},
1294 tensor_options);
1295 ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1296 ASSERT_TRUE(torch::allclose(
1297 result,
1298 ref_output,
1299 1e-7,
1300 1e-5,
1301 /*equal_nan=*/true));
1302}
1303
1304TEST_F(TransformerTest, TransformerDecoder) {
1305 transformer_decoder_test_helper(
1306 /*is_cuda=*/false, /*use_callable_activation=*/false);
1307 transformer_decoder_test_helper(
1308 /*is_cuda=*/false, /*use_callable_activation=*/true);
1309}
1310
1311TEST_F(TransformerTest, TransformerDecoder_CUDA) {
1312 transformer_decoder_test_helper(
1313 /*is_cuda=*/true, /*use_callable_activation=*/false);
1314 transformer_decoder_test_helper(
1315 /*is_cuda=*/true, /*use_callable_activation=*/true);
1316}
1317
1318TEST_F(TransformerTest, PrettyPrintTransformerDecoder) {
1319 LayerNorm norm = LayerNorm(LayerNormOptions({4}));
1320 TransformerDecoderOptions options(
1321 TransformerDecoderOptions(TransformerDecoderLayerOptions(4, 2), 2)
1322 .norm(AnyModule(norm)));
1323 ASSERT_EQ(
1324 c10::str(TransformerDecoder(options)),
1325 "torch::nn::TransformerDecoderImpl(\n"
1326 " (layers): torch::nn::ModuleList(\n"
1327 " (0): torch::nn::TransformerDecoderLayerImpl(\n"
1328 " (self_attn): torch::nn::MultiheadAttention(\n"
1329 " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1330 " )\n"
1331 " (multihead_attn): torch::nn::MultiheadAttention(\n"
1332 " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1333 " )\n"
1334 " (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
1335 " (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
1336 " (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
1337 " (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1338 " (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1339 " (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1340 " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
1341 " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
1342 " (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
1343 " )\n"
1344 " (1): torch::nn::TransformerDecoderLayerImpl(\n"
1345 " (self_attn): torch::nn::MultiheadAttention(\n"
1346 " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1347 " )\n"
1348 " (multihead_attn): torch::nn::MultiheadAttention(\n"
1349 " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1350 " )\n"
1351 " (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
1352 " (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
1353 " (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
1354 " (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1355 " (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1356 " (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1357 " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
1358 " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
1359 " (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
1360 " )\n"
1361 " )\n"
1362 " (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1363 ")");
1364}
1365
// Deterministic end-to-end test for the full Transformer module. The same
// inputs are run through (a) a Transformer with internally created
// encoder/decoder stacks and (b) a Transformer assembled from custom
// encoder/decoder modules; both are checked against precomputed reference
// outputs and required to agree with each other exactly.
//
// Args:
//   is_cuda: run the models and tensors on CUDA instead of CPU.
//   use_callable_activation: exercise the std::function activation overload
//       instead of the enum-based one.
void transformer_test_helper(bool is_cuda, bool use_callable_activation) {
  // this is a deterministic test for Transformer
  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
  torch::TensorOptions tensor_options =
      torch::TensorOptions().dtype(torch::kFloat32).device(device);

  // transformer created encoder/decoder
  auto options = TransformerOptions()
                     .d_model(4)
                     .nhead(2)
                     .num_encoder_layers(2)
                     .num_decoder_layers(1)
                     .dim_feedforward(16)
                     .dropout(0.0)
                     .activation(torch::kReLU);
  if (use_callable_activation) {
    options.activation(
        [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
  }
  Transformer model(options);

  // constant weights make the forward pass reproducible
  set_parameter_to_constants<Transformer>(model, tensor_options);
  if (tensor_options.device() == torch::kCUDA) {
    model->to(torch::kCUDA);
  }

  // transformer with customized encoder/decoder; hyper-parameters mirror the
  // internally-created stacks above so both models compute the same function
  LayerNorm enorm(LayerNormOptions({4}));
  TransformerEncoder encoder(
      TransformerEncoderOptions(
          TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
          2)
          .norm(AnyModule(enorm)));

  LayerNorm dnorm(LayerNormOptions({4}));
  TransformerDecoder decoder(
      TransformerDecoderOptions(
          TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
          1)
          .norm(AnyModule(dnorm)));

  Transformer model_cus(TransformerOptions()
                            .d_model(4)
                            .nhead(2)
                            .custom_encoder(AnyModule(encoder))
                            .custom_decoder(AnyModule(decoder)));

  set_parameter_to_constants<Transformer>(model_cus, tensor_options);
  if (tensor_options.device() == torch::kCUDA) {
    model_cus->to(torch::kCUDA);
  }

  // test cases: src is (seq=3, batch=2, d_model=4), tgt is (seq=2, batch=2)
  torch::Tensor src = torch::tensor(
      {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
       {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}},
       {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}},
      tensor_options);

  torch::Tensor tgt = torch::tensor(
      {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
       {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}},
      tensor_options);

  // no masks: both models must match the reference and each other exactly
  torch::Tensor ref_output = torch::tensor(
      {{{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}},
       {{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}}},
      tensor_options);
  torch::Tensor result = model(src, tgt);
  torch::Tensor result_cus = model_cus(src, tgt);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(result.equal(result_cus));
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // causal mask on src; reference output is unchanged here
  torch::Tensor src_mask =
      Transformer::Impl::generate_square_subsequent_mask(src.size(0))
          .to(tensor_options);
  ref_output = torch::tensor(
      {{{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}},
       {{2.695875, 0.347114, -0.044355, -0.549541},
        {2.696091, 0.347015, -0.044770, -0.548522}}},
      tensor_options);
  result = model(src, tgt, src_mask);
  result_cus = model_cus(src, tgt, src_mask);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(result.equal(result_cus));
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));

  // tgt_key_padding_mask (shape: batch x tgt_seq) with two positions masked
  torch::Tensor tgt_key_padding_mask =
      torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1;
  tgt_key_padding_mask[0][0] = 1;
  tgt_key_padding_mask[1][1] = 1;
  ref_output = torch::tensor(
      {{{2.696114, 0.347004, -0.044813, -0.548417},
        {2.696091, 0.347015, -0.044770, -0.548522}},
       {{2.696114, 0.347004, -0.044813, -0.548417},
        {2.696091, 0.347015, -0.044770, -0.548522}}},
      tensor_options);
  // empty tensors for the intervening mask arguments (tgt_mask, memory_mask,
  // src_key_padding_mask) act as "no mask"
  result = model(
      src,
      tgt,
      src_mask,
      torch::Tensor(),
      torch::Tensor(),
      torch::Tensor(),
      tgt_key_padding_mask);
  result_cus = model_cus(
      src,
      tgt,
      src_mask,
      torch::Tensor(),
      torch::Tensor(),
      torch::Tensor(),
      tgt_key_padding_mask);
  ASSERT_EQ(result.sizes(), ref_output.sizes());
  ASSERT_TRUE(result.equal(result_cus));
  ASSERT_TRUE(
      torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
}
1490
1491TEST_F(TransformerTest, Transformer) {
1492 transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/false);
1493 transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/true);
1494}
1495
1496TEST_F(TransformerTest, Transformer_CUDA) {
1497 transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/false);
1498 transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/true);
1499}
1500
1501TEST_F(TransformerTest, TransformerArgsCorrectness) {
1502 Transformer model(TransformerOptions()
1503 .d_model(4)
1504 .nhead(2)
1505 .num_encoder_layers(2)
1506 .num_decoder_layers(1)
1507 .dim_feedforward(16)
1508 .dropout(0.0)
1509 .activation(torch::kReLU));
1510
1511 torch::Tensor src = torch::randn({2, 3, 4});
1512 torch::Tensor tgt = torch::randn({3, 2, 4});
1513
1514 ASSERT_THROWS_WITH(
1515 model(src, tgt), "src and tgt should have equal batch size");
1516
1517 tgt = torch::randn({2, 3, 3});
1518 ASSERT_THROWS_WITH(
1519 model(src, tgt), "src and tgt should have same feature size as d_model");
1520
1521 src = torch::randn({2, 3});
1522 ASSERT_THROWS_WITH(model(src, tgt), "src and tgt should have 3 dimensions");
1523}
1524