1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <torch/torch.h> |
5 | |
6 | #include <algorithm> |
7 | #include <memory> |
8 | #include <vector> |
9 | |
10 | #include <test/cpp/api/support.h> |
11 | |
12 | using namespace torch::nn; |
13 | using namespace torch::test; |
14 | |
// Test fixture that seeds the RNG before each test, so tests that sample
// random tensors are reproducible.
struct SequentialTest : torch::test::SeedingFixture {};
16 | |
TEST_F(SequentialTest, CanContainThings) {
  // Construction-only sanity check: Sequential accepts a heterogeneous mix of
  // built-in module holders without error.
  Sequential sequential(Linear(3, 4), ReLU(), BatchNorm1d(3));
}
20 | |
21 | TEST_F(SequentialTest, ConstructsFromSharedPointer) { |
22 | struct M : torch::nn::Module { |
23 | explicit M(int value_) : value(value_) {} |
24 | int value; |
25 | int forward() { |
26 | return value; |
27 | } |
28 | }; |
29 | Sequential sequential( |
30 | std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)); |
31 | ASSERT_EQ(sequential->size(), 3); |
32 | |
33 | Sequential sequential_named( |
34 | {{"m1" , std::make_shared<M>(1)}, |
35 | {std::string("m2" ), std::make_shared<M>(2)}, |
36 | {"m3" , std::make_shared<M>(3)}}); |
37 | ASSERT_EQ(sequential->size(), 3); |
38 | } |
39 | |
40 | TEST_F(SequentialTest, ConstructsFromConcreteType) { |
41 | static int copy_count; |
42 | |
43 | struct M : torch::nn::Module { |
44 | explicit M(int value_) : value(value_) {} |
45 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
46 | M(const M& other) : torch::nn::Module(other) { |
47 | copy_count++; |
48 | } |
49 | int value; |
50 | int forward() { |
51 | return value; |
52 | } |
53 | }; |
54 | |
55 | copy_count = 0; |
56 | Sequential sequential(M(1), M(2), M(3)); |
57 | ASSERT_EQ(sequential->size(), 3); |
58 | // NOTE: The current implementation expects each module to be copied exactly |
59 | // once, which happens when the module is passed into `std::make_shared<T>()`. |
60 | // TODO: Find a way to avoid copying, and then delete the copy constructor of |
61 | // `M`. |
62 | ASSERT_EQ(copy_count, 3); |
63 | |
64 | copy_count = 0; |
65 | Sequential sequential_named( |
66 | {{"m1" , M(1)}, {std::string("m2" ), M(2)}, {"m3" , M(3)}}); |
67 | ASSERT_EQ(sequential->size(), 3); |
68 | ASSERT_EQ(copy_count, 3); |
69 | } |
70 | |
71 | TEST_F(SequentialTest, ConstructsFromModuleHolder) { |
72 | struct MImpl : torch::nn::Module { |
73 | explicit MImpl(int value_) : value(value_) {} |
74 | int forward() { |
75 | return value; |
76 | } |
77 | int value; |
78 | }; |
79 | |
80 | struct M : torch::nn::ModuleHolder<MImpl> { |
81 | using torch::nn::ModuleHolder<MImpl>::ModuleHolder; |
82 | using torch::nn::ModuleHolder<MImpl>::get; |
83 | }; |
84 | |
85 | Sequential sequential(M(1), M(2), M(3)); |
86 | ASSERT_EQ(sequential->size(), 3); |
87 | |
88 | Sequential sequential_named( |
89 | {{"m1" , M(1)}, {std::string("m2" ), M(2)}, {"m3" , M(3)}}); |
90 | ASSERT_EQ(sequential->size(), 3); |
91 | } |
92 | |
93 | TEST_F(SequentialTest, PushBackAddsAnElement) { |
94 | struct M : torch::nn::Module { |
95 | explicit M(int value_) : value(value_) {} |
96 | int forward() { |
97 | return value; |
98 | } |
99 | int value; |
100 | }; |
101 | |
102 | // Test unnamed submodules |
103 | Sequential sequential; |
104 | ASSERT_EQ(sequential->size(), 0); |
105 | ASSERT_TRUE(sequential->is_empty()); |
106 | sequential->push_back(Linear(3, 4)); |
107 | ASSERT_EQ(sequential->size(), 1); |
108 | sequential->push_back(std::make_shared<M>(1)); |
109 | ASSERT_EQ(sequential->size(), 2); |
110 | sequential->push_back(M(2)); |
111 | ASSERT_EQ(sequential->size(), 3); |
112 | |
113 | // Mix named and unnamed submodules |
114 | Sequential sequential_named; |
115 | ASSERT_EQ(sequential_named->size(), 0); |
116 | ASSERT_TRUE(sequential_named->is_empty()); |
117 | |
118 | sequential_named->push_back(Linear(3, 4)); |
119 | ASSERT_EQ(sequential_named->size(), 1); |
120 | ASSERT_EQ(sequential_named->named_children()[0].key(), "0" ); |
121 | sequential_named->push_back(std::string("linear2" ), Linear(3, 4)); |
122 | ASSERT_EQ(sequential_named->size(), 2); |
123 | ASSERT_EQ(sequential_named->named_children()[1].key(), "linear2" ); |
124 | |
125 | sequential_named->push_back("shared_m1" , std::make_shared<M>(1)); |
126 | ASSERT_EQ(sequential_named->size(), 3); |
127 | ASSERT_EQ(sequential_named->named_children()[2].key(), "shared_m1" ); |
128 | sequential_named->push_back(std::make_shared<M>(1)); |
129 | ASSERT_EQ(sequential_named->size(), 4); |
130 | ASSERT_EQ(sequential_named->named_children()[3].key(), "3" ); |
131 | |
132 | sequential_named->push_back(M(1)); |
133 | ASSERT_EQ(sequential_named->size(), 5); |
134 | ASSERT_EQ(sequential_named->named_children()[4].key(), "4" ); |
135 | sequential_named->push_back(std::string("m2" ), M(1)); |
136 | ASSERT_EQ(sequential_named->size(), 6); |
137 | ASSERT_EQ(sequential_named->named_children()[5].key(), "m2" ); |
138 | |
139 | // named and unnamed AnyModule's |
140 | Sequential sequential_any; |
141 | auto a = torch::nn::AnyModule(torch::nn::Linear(1, 2)); |
142 | ASSERT_EQ(sequential_any->size(), 0); |
143 | ASSERT_TRUE(sequential_any->is_empty()); |
144 | sequential_any->push_back(a); |
145 | ASSERT_EQ(sequential_any->size(), 1); |
146 | ASSERT_EQ(sequential_any->named_children()[0].key(), "0" ); |
147 | sequential_any->push_back("fc" , a); |
148 | ASSERT_EQ(sequential_any->size(), 2); |
149 | ASSERT_EQ(sequential_any->named_children()[1].key(), "fc" ); |
150 | } |
151 | |
152 | TEST_F(SequentialTest, AccessWithAt) { |
153 | struct M : torch::nn::Module { |
154 | explicit M(int value_) : value(value_) {} |
155 | int forward() { |
156 | return value; |
157 | } |
158 | int value; |
159 | }; |
160 | std::vector<std::shared_ptr<M>> modules = { |
161 | std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)}; |
162 | |
163 | Sequential sequential; |
164 | for (auto& module : modules) { |
165 | sequential->push_back(module); |
166 | } |
167 | ASSERT_EQ(sequential->size(), 3); |
168 | |
169 | // returns the correct module for a given index |
170 | for (const auto i : c10::irange(modules.size())) { |
171 | ASSERT_EQ(&sequential->at<M>(i), modules[i].get()); |
172 | } |
173 | |
174 | // throws for a bad index |
175 | ASSERT_THROWS_WITH( |
176 | sequential->at<M>(modules.size() + 1), "Index out of range" ); |
177 | ASSERT_THROWS_WITH( |
178 | sequential->at<M>(modules.size() + 1000000), "Index out of range" ); |
179 | } |
180 | |
181 | TEST_F(SequentialTest, AccessWithPtr) { |
182 | struct M : torch::nn::Module { |
183 | explicit M(int value_) : value(value_) {} |
184 | int forward() { |
185 | return value; |
186 | } |
187 | int value; |
188 | }; |
189 | std::vector<std::shared_ptr<M>> modules = { |
190 | std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)}; |
191 | |
192 | Sequential sequential; |
193 | for (auto& module : modules) { |
194 | sequential->push_back(module); |
195 | } |
196 | ASSERT_EQ(sequential->size(), 3); |
197 | |
198 | // returns the correct module for a given index |
199 | for (const auto i : c10::irange(modules.size())) { |
200 | ASSERT_EQ(sequential->ptr(i).get(), modules[i].get()); |
201 | ASSERT_EQ(sequential[i].get(), modules[i].get()); |
202 | ASSERT_EQ(sequential->ptr<M>(i).get(), modules[i].get()); |
203 | } |
204 | |
205 | // throws for a bad index |
206 | ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1), "Index out of range" ); |
207 | ASSERT_THROWS_WITH( |
208 | sequential->ptr(modules.size() + 1000000), "Index out of range" ); |
209 | } |
210 | |
TEST_F(SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
  // forward() has no sensible result when there are no submodules, so it
  // must throw rather than return something arbitrary.
  Sequential empty;
  ASSERT_THROWS_WITH(
      empty->forward<int>(), "Cannot call forward() on an empty Sequential");
}
216 | |
217 | TEST_F(SequentialTest, CallingForwardChainsCorrectly) { |
218 | struct MockModule : torch::nn::Module { |
219 | explicit MockModule(int value) : expected(value) {} |
220 | int expected; |
221 | int forward(int value) { |
222 | assert(value == expected); |
223 | return value + 1; |
224 | } |
225 | }; |
226 | |
227 | Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3}); |
228 | |
229 | ASSERT_EQ(sequential->forward<int>(1), 4); |
230 | } |
231 | |
232 | TEST_F(SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) { |
233 | struct M : public torch::nn::Module { |
234 | int forward() { |
235 | return 5; |
236 | } |
237 | }; |
238 | |
239 | Sequential sequential(M{}); |
240 | ASSERT_EQ(sequential->forward<int>(), 5); |
241 | ASSERT_THROWS_WITH( |
242 | sequential->forward<float>(), |
243 | "The type of the return value is int, but you asked for type float" ); |
244 | } |
245 | |
246 | TEST_F(SequentialTest, TheReturnTypeOfForwardDefaultsToTensor) { |
247 | struct M : public torch::nn::Module { |
248 | torch::Tensor forward(torch::Tensor v) { |
249 | return v; |
250 | } |
251 | }; |
252 | |
253 | Sequential sequential(M{}); |
254 | auto variable = torch::ones({3, 3}, torch::requires_grad()); |
255 | ASSERT_TRUE(sequential->forward(variable).equal(variable)); |
256 | } |
257 | |
258 | TEST_F(SequentialTest, ForwardReturnsTheLastValue) { |
259 | torch::manual_seed(0); |
260 | Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100)); |
261 | |
262 | auto x = torch::randn({1000, 10}, torch::requires_grad()); |
263 | auto y = sequential->forward(x); |
264 | ASSERT_EQ(y.ndimension(), 2); |
265 | ASSERT_EQ(y.size(0), 1000); |
266 | ASSERT_EQ(y.size(1), 100); |
267 | } |
268 | |
TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
  // Construction-only check: Sequential must hold a representative mix of
  // built-in modules (linear, conv, dropout, norm, embedding, recurrent).
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
}
278 | |
279 | TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) { |
280 | struct A : torch::nn::Module { |
281 | int forward(int x) { |
282 | return x; |
283 | } |
284 | }; |
285 | struct B : torch::nn::Module { |
286 | int forward(int x) { |
287 | return x; |
288 | } |
289 | }; |
290 | struct C : torch::nn::Module { |
291 | int forward(int x) { |
292 | return x; |
293 | } |
294 | }; |
295 | struct D : torch::nn::Module { |
296 | int forward(int x) { |
297 | return x; |
298 | } |
299 | }; |
300 | Sequential a(A{}, B{}); |
301 | Sequential b(C{}, D{}); |
302 | a->extend(*b); |
303 | |
304 | ASSERT_EQ(a->size(), 4); |
305 | ASSERT_TRUE(a[0]->as<A>()); |
306 | ASSERT_TRUE(a[1]->as<B>()); |
307 | ASSERT_TRUE(a[2]->as<C>()); |
308 | ASSERT_TRUE(a[3]->as<D>()); |
309 | |
310 | ASSERT_EQ(b->size(), 2); |
311 | ASSERT_TRUE(b[0]->as<C>()); |
312 | ASSERT_TRUE(b[1]->as<D>()); |
313 | |
314 | std::vector<std::shared_ptr<A>> c = { |
315 | std::make_shared<A>(), std::make_shared<A>()}; |
316 | b->extend(c); |
317 | |
318 | ASSERT_EQ(b->size(), 4); |
319 | ASSERT_TRUE(b[0]->as<C>()); |
320 | ASSERT_TRUE(b[1]->as<D>()); |
321 | ASSERT_TRUE(b[2]->as<A>()); |
322 | ASSERT_TRUE(b[3]->as<A>()); |
323 | } |
324 | |
325 | TEST_F(SequentialTest, HasReferenceSemantics) { |
326 | Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5)); |
327 | Sequential second(first); |
328 | |
329 | ASSERT_EQ(first.get(), second.get()); |
330 | ASSERT_EQ(first->size(), second->size()); |
331 | ASSERT_TRUE(std::equal( |
332 | first->begin(), |
333 | first->end(), |
334 | second->begin(), |
335 | [](const AnyModule& first, const AnyModule& second) { |
336 | return &first == &second; |
337 | })); |
338 | } |
339 | |
340 | TEST_F(SequentialTest, IsCloneable) { |
341 | Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3)); |
342 | Sequential clone = |
343 | std::dynamic_pointer_cast<SequentialImpl>(sequential->clone()); |
344 | ASSERT_EQ(sequential->size(), clone->size()); |
345 | |
346 | for (size_t i = 0; i < sequential->size(); ++i) { |
347 | // The modules should be the same kind (type). |
348 | ASSERT_EQ(sequential[i]->name(), clone[i]->name()); |
349 | // But not pointer-equal (distinct objects). |
350 | ASSERT_NE(sequential[i], clone[i]); |
351 | } |
352 | |
353 | // Verify that the clone is deep, i.e. parameters of modules are cloned too. |
354 | |
355 | torch::NoGradGuard no_grad; |
356 | |
357 | auto params1 = sequential->named_parameters(); |
358 | auto params2 = clone->named_parameters(); |
359 | ASSERT_EQ(params1.size(), params2.size()); |
360 | for (auto& param : params1) { |
361 | ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()])); |
362 | ASSERT_EQ(param->device(), params2[param.key()].device()); |
363 | ASSERT_TRUE(param->allclose(params2[param.key()])); |
364 | param->add_(2); |
365 | } |
366 | for (auto& param : params1) { |
367 | ASSERT_FALSE(param->allclose(params2[param.key()])); |
368 | } |
369 | } |
370 | |
371 | TEST_F(SequentialTest, RegistersElementsAsSubmodules) { |
372 | Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5)); |
373 | |
374 | auto modules = sequential->children(); |
375 | ASSERT_TRUE(modules[0]->as<Linear>()); |
376 | ASSERT_TRUE(modules[1]->as<Conv2d>()); |
377 | ASSERT_TRUE(modules[2]->as<Dropout2d>()); |
378 | } |
379 | |
TEST_F(SequentialTest, CloneToDevice_CUDA) {
  // Cloning with an explicit device should place every parameter and buffer
  // of the clone on that device. (Runs only when CUDA is available — the
  // `_CUDA` suffix gates it in the test harness.)
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm1d(3));
  torch::Device device(torch::kCUDA, 0);
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
  for (const auto& p : clone->parameters()) {
    ASSERT_EQ(p.device(), device);
  }
  for (const auto& b : clone->buffers()) {
    ASSERT_EQ(b.device(), device);
  }
}
392 | |
TEST_F(SequentialTest, PrettyPrintSequential) {
  // Streaming a Sequential lists each child on its own line; unnamed
  // submodules are keyed by their positional index.
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm2d(5),
      Embedding(4, 10),
      LSTM(4, 5));
  ASSERT_EQ(
      c10::str(sequential),
      "torch::nn::Sequential(\n"
      "  (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");

  // With named construction, the provided names replace positional indices.
  Sequential sequential_named(
      {{"linear", Linear(10, 3)},
       {"conv2d", Conv2d(1, 2, 3)},
       {"dropout", Dropout(0.5)},
       {"batchnorm2d", BatchNorm2d(5)},
       {"embedding", Embedding(4, 10)},
       {"lstm", LSTM(4, 5)}});
  ASSERT_EQ(
      c10::str(sequential_named),
      "torch::nn::Sequential(\n"
      "  (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
      "  (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
      "  (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
      "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
      ")");
}
430 | |
// Exercises Sequential::forward with modules whose forward() takes optional
// or multiple arguments, and with non-Tensor return types. Each scope below
// is an independent sub-case; expected values are fixed numeric fixtures
// (for seeded RNG cases, presumably the values observed under seed 0 —
// verify against the corresponding Python tests if they drift).
TEST_F(SequentialTest, ModuleForwardMethodOptionalArg) {
  {
    // ConvTranspose1d with deterministic (arange) weights and input.
    Sequential sequential(
        Identity(),
        ConvTranspose1d(ConvTranspose1dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose1dImpl>(sequential[1])
        ->weight.set_data(torch::arange(18.).reshape({3, 2, 3}));
    auto x = torch::arange(30.).reshape({2, 3, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{150., 333., 552., 615., 678., 501., 276.},
          {195., 432., 714., 804., 894., 654., 357.}},
         {{420., 918., 1497., 1560., 1623., 1176., 636.},
          {600., 1287., 2064., 2154., 2244., 1599., 852.}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    // ConvTranspose2d, same deterministic-weight pattern.
    Sequential sequential(
        Identity(),
        ConvTranspose2d(ConvTranspose2dOptions(3, 2, 3).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose2dImpl>(sequential[1])
        ->weight.set_data(torch::arange(54.).reshape({3, 2, 3, 3}));
    auto x = torch::arange(75.).reshape({1, 3, 5, 5});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{2250., 4629., 7140., 7311., 7482., 5133., 2640.},
           {4995., 10272., 15837., 16206., 16575., 11364., 5841.},
           {8280., 17019., 26226., 26820., 27414., 18783., 9648.},
           {9225., 18954., 29196., 29790., 30384., 20808., 10683.},
           {10170., 20889., 32166., 32760., 33354., 22833., 11718.},
           {7515., 15420., 23721., 24144., 24567., 16800., 8613.},
           {4140., 8487., 13044., 13269., 13494., 9219., 4722.}},
          {{2925., 6006., 9246., 9498., 9750., 6672., 3423.},
           {6480., 13296., 20454., 20985., 21516., 14712., 7542.},
           {10710., 21960., 33759., 34596., 35433., 24210., 12402.},
           {12060., 24705., 37944., 38781., 39618., 27045., 13842.},
           {13410., 27450., 42129., 42966., 43803., 29880., 15282.},
           {9810., 20064., 30768., 31353., 31938., 21768., 11124.},
           {5355., 10944., 16770., 17076., 17382., 11838., 6045.}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    // ConvTranspose3d, same deterministic-weight pattern.
    Sequential sequential(
        Identity(),
        ConvTranspose3d(ConvTranspose3dOptions(2, 2, 2).stride(1).bias(false)));
    std::dynamic_pointer_cast<ConvTranspose3dImpl>(sequential[1])
        ->weight.set_data(torch::arange(32.).reshape({2, 2, 2, 2, 2}));
    auto x = torch::arange(16.).reshape({1, 2, 2, 2, 2});
    auto y = sequential->forward(x);
    auto expected = torch::tensor(
        {{{{{128., 280., 154.}, {304., 664., 364.}, {184., 400., 218.}},
           {{352., 768., 420.}, {832., 1808., 984.}, {496., 1072., 580.}},
           {{256., 552., 298.}, {592., 1272., 684.}, {344., 736., 394.}}},
          {{{192., 424., 234.}, {464., 1016., 556.}, {280., 608., 330.}},
           {{544., 1184., 644.}, {1280., 2768., 1496.}, {752., 1616., 868.}},
           {{384., 824., 442.}, {880., 1880., 1004.}, {504., 1072., 570.}}}}});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    // EmbeddingBag defaults to mean-reduction over the bag.
    auto weight = torch::tensor({{1., 2.3, 3.}, {4., 5.1, 6.3}});
    Sequential sequential(Identity(), EmbeddingBag::from_pretrained(weight));
    auto x = torch::tensor({{1, 0}}, torch::kLong);
    auto y = sequential->forward(x);
    auto expected = torch::tensor({2.5000, 3.7000, 4.6500});
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    // MultiheadAttention: forward takes three arguments (query, key, value)
    // and returns a tuple (attn_output, attn_output_weights).
    torch::manual_seed(0);

    int64_t embed_dim = 8;
    int64_t num_heads = 4;
    int64_t batch_size = 8;
    int64_t src_len = 3;
    int64_t tgt_len = 1;

    auto query = torch::ones({batch_size, tgt_len, embed_dim});
    auto key = torch::ones({batch_size, src_len, embed_dim});
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto value = key;

    Sequential sequential(MultiheadAttention(embed_dim, num_heads));
    auto output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(
        query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1));

    // All-ones inputs make each batch row attend identically.
    auto attn_output = std::get<0>(output);
    auto attn_output_expected = torch::tensor(
        {{{0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674, -0.0056, 0.1324, 0.0922, 0.0160, -0.0934, -0.1700, 0.1663},
          {0.0674,
           -0.0056,
           0.1324,
           0.0922,
           0.0160,
           -0.0934,
           -0.1700,
           0.1663}}});
    ASSERT_TRUE(
        torch::allclose(attn_output, attn_output_expected, 1e-05, 2e-04));

    // Uniform attention over the 3 source positions (1/3 each).
    auto attn_output_weights = std::get<1>(output);
    auto attn_output_weights_expected = torch::tensor(
        {{{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}},
         {{0.3333, 0.3333, 0.3333}}});
    ASSERT_TRUE(torch::allclose(
        attn_output_weights, attn_output_weights_expected, 1e-05, 2e-04));
  }
  {
    // MaxUnpool1d: forward takes (input, indices).
    auto indices = torch::tensor({{{1, 3, 4}}}, torch::kLong);
    auto x = torch::tensor({{{2, 4, 5}}}, torch::dtype(torch::kFloat));
    Sequential sequential(MaxUnpool1d(3));
    auto y = sequential->forward(x, indices);
    auto expected =
        torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    // MaxUnpool2d with stride and padding options.
    auto indices = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}}},
        torch::kLong);
    auto x = torch::tensor(
        {{{{6, 8, 9}, {16, 18, 19}, {21, 23, 24}}},
         {{{31, 33, 34}, {41, 43, 44}, {46, 48, 49}}}},
        torch::dtype(torch::kFloat));
    Sequential sequential(
        MaxUnpool2d(MaxUnpool2dOptions(3).stride(2).padding(1)));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{0, 0, 0, 0, 0},
           {0, 6, 0, 8, 9},
           {0, 0, 0, 0, 0},
           {0, 16, 0, 18, 19},
           {0, 21, 0, 23, 24}}},
         {{{0, 0, 0, 0, 0},
           {0, 31, 0, 33, 34},
           {0, 0, 0, 0, 0},
           {0, 41, 0, 43, 44},
           {0, 46, 0, 48, 49}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    // MaxUnpool3d: single element scattered to flat index 26 of a 3x3x3 grid.
    auto indices = torch::tensor({{{{{26}}}}}, torch::kLong);
    auto x = torch::tensor(
        {{{{{26}}}}}, torch::dtype(torch::kFloat).requires_grad(true));
    Sequential sequential(MaxUnpool3d(3));
    auto y = sequential->forward(x, indices);
    auto expected = torch::tensor(
        {{{{{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}},
           {{0, 0, 0}, {0, 0, 0}, {0, 0, 26}}}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, expected));
  }
  {
    // RNN returns a (output, hidden_state) tuple.
    torch::manual_seed(0);
    Sequential sequential(Identity(), RNN(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output = torch::tensor(
        {{{-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531},
          {-0.0645, -0.7274, 0.4531}},
         {{-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009},
          {-0.3970, -0.6950, 0.6009}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    // LSTM returns (output, (h_n, c_n)) — a nested tuple.
    torch::manual_seed(0);
    Sequential sequential(Identity(), LSTM(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output = sequential->forward<
        std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
    auto expected_output = torch::tensor(
        {{{-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744},
          {-0.2693, -0.1240, 0.0744}},
         {{-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183},
          {-0.3889, -0.1919, 0.1183}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    // GRU returns a (output, hidden_state) tuple.
    torch::manual_seed(0);
    Sequential sequential(Identity(), GRU(2, 3));
    auto x = torch::ones({2, 3, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output = torch::tensor(
        {{{-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336},
          {-0.1134, 0.0467, 0.2336}},
         {{-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960},
          {-0.1189, 0.0502, 0.2960}}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    // RNNCell returns a plain Tensor (next hidden state).
    torch::manual_seed(0);
    Sequential sequential(Identity(), RNNCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output = sequential->forward<torch::Tensor>(x);
    auto expected_output =
        torch::tensor({{-0.0645, -0.7274, 0.4531}, {-0.0645, -0.7274, 0.4531}});
    ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
  }
  {
    // LSTMCell returns an (h, c) tuple.
    torch::manual_seed(0);
    Sequential sequential(Identity(), LSTMCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output =
        sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
    auto expected_output =
        torch::tensor({{-0.2693, -0.1240, 0.0744}, {-0.2693, -0.1240, 0.0744}});
    ASSERT_TRUE(torch::allclose(
        std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
  }
  {
    // GRUCell returns a plain Tensor (next hidden state).
    torch::manual_seed(0);
    Sequential sequential(Identity(), GRUCell(2, 3));
    auto x = torch::ones({2, 2});
    auto rnn_output = sequential->forward<torch::Tensor>(x);
    auto expected_output =
        torch::tensor({{-0.1134, 0.0467, 0.2336}, {-0.1134, 0.0467, 0.2336}});
    ASSERT_TRUE(torch::allclose(rnn_output, expected_output, 1e-05, 2e-04));
  }
}
674 | |