1#include <c10/core/Device.h>
2#include <c10/core/DeviceType.h>
3#include <gtest/gtest.h>
4#include <test/cpp/lazy/test_lazy_ops_util.h>
5#include <torch/csrc/lazy/core/debug_util.h>
6#include <torch/csrc/lazy/core/helpers.h>
7#include <torch/csrc/lazy/core/ir_builder.h>
8#include <torch/csrc/lazy/core/lazy_graph_executor.h>
9#include <torch/csrc/lazy/core/metrics.h>
10#include <torch/csrc/lazy/core/permutation_util.h>
11#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
12#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
13#include <torch/torch.h>
14#include <iostream>
15
16namespace torch {
17namespace lazy {
18
19// Lazy Tensor is disabled in FBCODE until addressing non-virtual methods (e.g.
20// sizes) in TensorImpl
21#ifndef FBCODE_CAFFE2
22
23namespace {
24// This registers the torchscript backend, without which lazy device won't work.
25// FIXME: This registers the backend for the whole test binary. We should
26// probably do it and undo it in the test fixture below.
27static bool inline init_backend() {
28 torch::lazy::InitTorchScriptBackend();
29 return true;
30}
31static const bool backend_initialized = init_backend();
32
33} // namespace
34
35class LazyTsTest : public ::testing::Test {
36 protected:
37 void SetUp() override;
38
39 void TearDown() override;
40
41 static void CommonSetup() {}
42
43 void ExpectCounterNotChanged(
44 const std::string& counter_regex,
45 const std::unordered_set<std::string>* ignore_set) {}
46
47 void ExpectCounterChanged(
48 const std::string& counter_regex,
49 const std::unordered_set<std::string>* ignore_set) {}
50
51 void ResetCounters() {}
52
53 private:
54 void MakeEndSnapshot() {}
55};
56
57class LazyOpsTestBase : public LazyTsTest {
58 protected:
59 static void SetUpTestCase() {}
60};
61
62void LazyTsTest::SetUp() {
63 (void)backend_initialized; // avoid unused parameter warning
64 at::manual_seed(42);
65 torch::lazy::LazyGraphExecutor::Get()->SetRngSeed(
66 torch::lazy::BackendDevice(), 42);
67}
68
69void LazyTsTest::TearDown() {}
70
71namespace {
72using torch::lazy::DebugUtil;
73
74class LazyOpsTest : public LazyOpsTestBase {};
75
76static inline bool IsCuda() {
77 return torch::lazy::getBackend()->EagerFallbackDeviceType() == at::kCUDA;
78}
79
80static inline at::DeviceType DefaultDevice() {
81 return torch::lazy::getBackend()->EagerFallbackDeviceType();
82}
83
84} // namespace
85
86TEST_F(LazyOpsTest, TestScalarTensor) {
87 torch::Tensor scalar_tensor = torch::scalar_tensor(
88 1., torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
89 ForEachDevice([&](const torch::Device& device) {
90 torch::Tensor lazy_scalar_tensor = torch::scalar_tensor(
91 1., torch::TensorOptions(torch::kFloat).device(torch::kLazy));
92 AllClose(scalar_tensor, lazy_scalar_tensor);
93 });
94}
95
96TEST_F(LazyOpsTest, TestClone) {
97 ForEachDevice([&](const torch::Device& device) {
98 torch::Tensor a = torch::rand(
99 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
100 torch::Tensor lazy_a = CopyToDevice(a, device);
101 torch::Tensor lazy_b = lazy_a.clone();
102 AllClose(a, lazy_b);
103 lazy_a.add_(1.0);
104 AllClose(a, lazy_b);
105 });
106}
107
108TEST_F(LazyOpsTest, TestTo) {
109 ForEachDevice([&](const torch::Device& device) {
110 torch::Tensor a = torch::rand(
111 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
112 torch::Tensor lazy_a = CopyToDevice(a, device);
113 AllClose(a, lazy_a);
114 });
115}
116
117TEST_F(LazyOpsTest, TestIsFloatingPoint) {
118 ForEachDevice([&](const torch::Device& device) {
119 torch::Tensor a = torch::rand(
120 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
121 torch::Tensor lazy_a = CopyToDevice(a, device);
122 bool is_float = torch::is_floating_point(a);
123 bool lazy_is_float = torch::is_floating_point(lazy_a);
124 EXPECT_EQ(is_float, lazy_is_float);
125 });
126}
127
128TEST_F(LazyOpsTest, TestIsSigned) {
129 ForEachDevice([&](const torch::Device& device) {
130 torch::Tensor a = torch::rand(
131 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
132 torch::Tensor lazy_a = CopyToDevice(a, device);
133 bool is_signed = torch::is_signed(a);
134 bool lazy_is_signed = torch::is_signed(lazy_a);
135 EXPECT_EQ(is_signed, lazy_is_signed);
136 });
137}
138
139TEST_F(LazyOpsTest, TestCastByte) {
140 torch::Tensor a =
141 torch::rand(
142 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
143 100.0;
144 torch::Tensor b = torch::_cast_Byte(a);
145 ForEachDevice([&](const torch::Device& device) {
146 torch::Tensor lazy_a = CopyToDevice(a, device);
147 torch::Tensor lazy_b = torch::_cast_Byte(lazy_a);
148 AllEqual(b, lazy_b);
149 });
150}
151
152TEST_F(LazyOpsTest, TestCastChar) {
153 torch::Tensor a =
154 torch::rand(
155 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
156 100.0;
157 torch::Tensor b = torch::_cast_Char(a);
158 ForEachDevice([&](const torch::Device& device) {
159 torch::Tensor lazy_a = CopyToDevice(a, device);
160 torch::Tensor lazy_b = torch::_cast_Char(lazy_a);
161 AllEqual(b, lazy_b);
162 });
163}
164
165TEST_F(LazyOpsTest, TestCastShort) {
166 torch::Tensor a =
167 torch::rand(
168 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
169 100.0;
170 torch::Tensor b = torch::_cast_Short(a);
171 ForEachDevice([&](const torch::Device& device) {
172 torch::Tensor lazy_a = CopyToDevice(a, device);
173 torch::Tensor lazy_b = torch::_cast_Short(lazy_a);
174 AllEqual(b, lazy_b);
175 });
176}
177
178TEST_F(LazyOpsTest, TestCastInt) {
179 torch::Tensor a =
180 torch::rand(
181 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
182 100.0;
183 torch::Tensor b = torch::_cast_Int(a);
184 ForEachDevice([&](const torch::Device& device) {
185 torch::Tensor lazy_a = CopyToDevice(a, device);
186 torch::Tensor lazy_b = torch::_cast_Int(lazy_a);
187 AllEqual(b, lazy_b);
188 });
189}
190
191TEST_F(LazyOpsTest, TestCastLong) {
192 torch::Tensor a =
193 torch::rand(
194 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
195 100.0;
196 torch::Tensor b = torch::_cast_Long(a);
197 ForEachDevice([&](const torch::Device& device) {
198 torch::Tensor lazy_a = CopyToDevice(a, device);
199 torch::Tensor lazy_b = torch::_cast_Long(lazy_a);
200 AllEqual(b, lazy_b);
201 });
202}
203
204TEST_F(LazyOpsTest, TestCastFloat) {
205 torch::Tensor a =
206 torch::rand(
207 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
208 100.0;
209 torch::Tensor b = torch::_cast_Float(a);
210 ForEachDevice([&](const torch::Device& device) {
211 torch::Tensor lazy_a = CopyToDevice(a, device);
212 torch::Tensor lazy_b = torch::_cast_Float(lazy_a);
213 AllEqual(b, lazy_b);
214 });
215}
216
217TEST_F(LazyOpsTest, TestRetainType) {
218 torch::Tensor lazy_a = torch::zeros(
219 {2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));
220 torch::Tensor lazy_b = torch::ones(
221 {2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));
222 torch::Tensor lazy_c = lazy_a + lazy_b;
223 EXPECT_EQ(lazy_c.scalar_type(), torch::ScalarType::Byte);
224}
225
226TEST_F(LazyOpsTest, TestLogicalTypeWithInterop) {
227 torch::Tensor query = torch::rand(
228 {2, 12, 20, 64},
229 torch::TensorOptions(torch::kFloat).device(torch::kLazy));
230 torch::Tensor key = torch::rand(
231 {2, 12, 64, 20},
232 torch::TensorOptions(torch::kFloat).device(torch::kLazy));
233 torch::Tensor scores =
234 torch::matmul(query, key) /
235 torch::scalar_tensor(
236 8, torch::TensorOptions(torch::kDouble).device(torch::kLazy));
237 torch::Tensor p_attn = torch::softmax(scores, /*dim=*/-1);
238 EXPECT_EQ(p_attn.scalar_type(), torch::ScalarType::Float);
239}
240
241TEST_F(LazyOpsTest, TestAdd) {
242 torch::Tensor a = torch::rand(
243 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
244 torch::Tensor b = torch::rand(
245 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
246 torch::Tensor c = torch::add(a, b);
247 ForEachDevice([&](const torch::Device& device) {
248 torch::Tensor lazy_a = CopyToDevice(a, device);
249 torch::Tensor lazy_b = CopyToDevice(b, device);
250 torch::Tensor lazy_c = torch::add(lazy_a, lazy_b);
251 AllClose(c, lazy_c);
252 });
253}
254
255TEST_F(LazyOpsTest, TestAddHalf) {
256 torch::Tensor a = torch::rand(
257 {2, 2}, torch::TensorOptions(torch::kHalf).device(DefaultDevice()));
258 torch::Tensor b = torch::rand(
259 {2, 2}, torch::TensorOptions(torch::kHalf).device(DefaultDevice()));
260 torch::Tensor c = torch::add(a, b);
261 ForEachDevice([&](const torch::Device& device) {
262 torch::Tensor lazy_a = CopyToDevice(a, device);
263 torch::Tensor lazy_b = CopyToDevice(b, device);
264 torch::Tensor lazy_c = torch::add(lazy_a, lazy_b);
265 AllClose(c, lazy_c);
266 });
267}
268
269TEST_F(LazyOpsTest, TestAddMixedPrecision) {
270 torch::Tensor a = torch::rand(
271 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
272 torch::Tensor b = torch::rand(
273 {2, 2}, torch::TensorOptions(torch::kHalf).device(DefaultDevice()));
274 torch::Tensor c = torch::add(a, b);
275 ForEachDevice([&](const torch::Device& device) {
276 torch::Tensor lazy_a = CopyToDevice(a, device);
277 torch::Tensor lazy_b = CopyToDevice(b, device);
278 torch::Tensor lazy_c = torch::add(lazy_a, lazy_b);
279 AllClose(c, lazy_c);
280 });
281}
282
283TEST_F(LazyOpsTest, TestAddInPlace) {
284 ForEachDevice([&](const torch::Device& device) {
285 torch::Tensor a = torch::rand(
286 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
287 torch::Tensor lazy_a = CopyToDevice(a, device);
288 torch::Tensor b = torch::rand(
289 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
290 torch::Tensor lazy_b = CopyToDevice(b, device);
291 torch::Tensor c = a.add_(b);
292 torch::Tensor lazy_c = lazy_a.add_(lazy_b);
293 AllClose(a, lazy_a);
294 AllClose(c, lazy_c);
295 });
296}
297
298TEST_F(LazyOpsTest, TestAddScalar) {
299 torch::Tensor a = torch::rand(
300 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
301 torch::Scalar b(1);
302 torch::Tensor c = torch::add(a, b);
303 ForEachDevice([&](const torch::Device& device) {
304 torch::Tensor lazy_a = CopyToDevice(a, device);
305 torch::Tensor lazy_c = torch::add(lazy_a, b);
306 AllClose(c, lazy_c);
307 });
308}
309
310TEST_F(LazyOpsTest, TestAddScalarInPlace) {
311 torch::Scalar b(1);
312 ForEachDevice([&](const torch::Device& device) {
313 torch::Tensor a = torch::rand(
314 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
315 torch::Tensor lazy_a = CopyToDevice(a, device);
316 torch::Tensor c = a.add_(b);
317 torch::Tensor lazy_c = lazy_a.add_(b);
318 AllClose(a, lazy_a);
319 AllClose(c, lazy_c);
320 });
321}
322
323TEST_F(LazyOpsTest, TestAddZeroSizeDim) {
324 torch::Tensor a = torch::rand(
325 {0, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
326 torch::Tensor b = torch::rand(
327 {1, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
328 torch::Tensor c = torch::add(a, b);
329 ForEachDevice([&](const torch::Device& device) {
330 torch::Tensor lazy_a = CopyToDevice(a, device);
331 torch::Tensor lazy_b = CopyToDevice(b, device);
332 torch::Tensor lazy_c = torch::add(lazy_a, lazy_b);
333 AllClose(c, lazy_c);
334 });
335}
336
337TEST_F(LazyOpsTest, TestSub) {
338 torch::Tensor a = torch::rand(
339 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
340 torch::Tensor b = torch::rand(
341 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
342 torch::Tensor c = torch::sub(a, b);
343 ForEachDevice([&](const torch::Device& device) {
344 torch::Tensor lazy_a = CopyToDevice(a, device);
345 torch::Tensor lazy_b = CopyToDevice(b, device);
346 torch::Tensor lazy_c = torch::sub(lazy_a, lazy_b);
347 AllClose(c, lazy_c);
348 });
349}
350
351TEST_F(LazyOpsTest, TestSubInPlace) {
352 ForEachDevice([&](const torch::Device& device) {
353 torch::Tensor a = torch::rand(
354 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
355 torch::Tensor lazy_a = CopyToDevice(a, device);
356 torch::Tensor b = torch::rand(
357 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
358 torch::Tensor lazy_b = CopyToDevice(b, device);
359 torch::Tensor c = a.sub_(b);
360 torch::Tensor lazy_c = lazy_a.sub_(lazy_b);
361 AllClose(a, lazy_a);
362 AllClose(c, lazy_c);
363 });
364}
365
366TEST_F(LazyOpsTest, TestSubScalar) {
367 torch::Tensor a = torch::rand(
368 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
369 torch::Scalar b(1);
370 torch::Tensor c = torch::sub(a, b);
371 ForEachDevice([&](const torch::Device& device) {
372 torch::Tensor lazy_a = CopyToDevice(a, device);
373 torch::Tensor lazy_c = torch::sub(lazy_a, b);
374 AllClose(c, lazy_c);
375 });
376}
377
378TEST_F(LazyOpsTest, TestSubScalarInPlace) {
379 torch::Scalar b(1);
380 ForEachDevice([&](const torch::Device& device) {
381 torch::Tensor a = torch::rand(
382 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
383 torch::Tensor lazy_a = CopyToDevice(a, device);
384 torch::Tensor c = a.sub_(b);
385 torch::Tensor lazy_c = lazy_a.sub_(b);
386 AllClose(a, lazy_a);
387 AllClose(c, lazy_c);
388 });
389}
390
391TEST_F(LazyOpsTest, TestMul) {
392 torch::Tensor a = torch::rand(
393 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
394 torch::Tensor b = torch::rand(
395 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
396 torch::Tensor c = torch::mul(a, b);
397 ForEachDevice([&](const torch::Device& device) {
398 torch::Tensor lazy_a = CopyToDevice(a, device);
399 torch::Tensor lazy_b = CopyToDevice(b, device);
400 torch::Tensor lazy_c = torch::mul(lazy_a, lazy_b);
401 AllClose(c, lazy_c);
402 });
403}
404
405TEST_F(LazyOpsTest, TestMulInPlace) {
406 ForEachDevice([&](const torch::Device& device) {
407 torch::Tensor a = torch::rand(
408 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
409 torch::Tensor lazy_a = CopyToDevice(a, device);
410 torch::Tensor b = torch::rand(
411 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
412 torch::Tensor lazy_b = CopyToDevice(b, device);
413 torch::Tensor c = a.mul_(b);
414 torch::Tensor lazy_c = lazy_a.mul_(lazy_b);
415 AllClose(a, lazy_a);
416 AllClose(c, lazy_c);
417 });
418}
419
420TEST_F(LazyOpsTest, TestMulScalar) {
421 torch::Tensor a = torch::rand(
422 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
423 torch::Scalar b(3);
424 torch::Tensor c = torch::mul(a, b);
425 ForEachDevice([&](const torch::Device& device) {
426 torch::Tensor lazy_a = CopyToDevice(a, device);
427 torch::Tensor lazy_c = torch::mul(lazy_a, b);
428 AllClose(c, lazy_c);
429 });
430}
431
432TEST_F(LazyOpsTest, TestMulScalarInPlace) {
433 torch::Scalar b(3);
434 ForEachDevice([&](const torch::Device& device) {
435 torch::Tensor a = torch::rand(
436 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
437 torch::Tensor lazy_a = CopyToDevice(a, device);
438 torch::Tensor c = a.mul_(b);
439 torch::Tensor lazy_c = lazy_a.mul_(b);
440 AllClose(a, lazy_a);
441 AllClose(c, lazy_c);
442 });
443}
444
445TEST_F(LazyOpsTest, TestDiv) {
446 for (torch::ScalarType scalar_type1 :
447 {torch::kFloat,
448 torch::kByte,
449 torch::kChar,
450 torch::kShort,
451 torch::kInt,
452 torch::kLong}) {
453 torch::Tensor a = isFloatingType(scalar_type1)
454 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1))
455 : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type1));
456 for (torch::ScalarType scalar_type2 :
457 {torch::kFloat,
458 torch::kByte,
459 torch::kChar,
460 torch::kShort,
461 torch::kInt,
462 torch::kLong}) {
463 torch::Tensor b = isFloatingType(scalar_type2)
464 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2))
465 : torch::randint(1, 100, {3, 4}, torch::TensorOptions(scalar_type2));
466 torch::Tensor c = torch::div(a, b);
467 ForEachDevice([&](const torch::Device& device) {
468 torch::Tensor lazy_a = CopyToDevice(a, device);
469 torch::Tensor lazy_b = CopyToDevice(b, device);
470 torch::Tensor lazy_c = torch::div(lazy_a, lazy_b);
471 AllClose(c, lazy_c);
472 });
473 }
474 }
475}
476
477TEST_F(LazyOpsTest, TestDivWithRoundingMode) {
478 c10::optional<c10::string_view> rounding_modes[] = {
479 "trunc", "floor", c10::nullopt};
480 for (const auto& rounding_mode : rounding_modes) {
481 for (torch::ScalarType scalar_type1 :
482 {torch::kFloat,
483 torch::kByte,
484 torch::kChar,
485 torch::kShort,
486 torch::kInt,
487 torch::kLong}) {
488 int lower_bound = (scalar_type1 == torch::kByte) ? 0 : -100;
489 torch::Tensor a = isFloatingType(scalar_type1)
490 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1))
491 : torch::randint(
492 lower_bound, 50, {3, 4}, torch::TensorOptions(scalar_type1));
493 for (torch::ScalarType scalar_type2 :
494 {torch::kFloat,
495 torch::kByte,
496 torch::kChar,
497 torch::kShort,
498 torch::kInt,
499 torch::kLong}) {
500 torch::Tensor b = isFloatingType(scalar_type2)
501 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2))
502 : torch::randint(
503 51, 100, {3, 4}, torch::TensorOptions(scalar_type2));
504 torch::Tensor c = torch::div(a, b, rounding_mode);
505 ForEachDevice([&](const torch::Device& device) {
506 torch::Tensor lazy_a = CopyToDevice(a, device);
507 torch::Tensor lazy_b = CopyToDevice(b, device);
508 torch::Tensor lazy_c = torch::div(lazy_a, lazy_b, rounding_mode);
509 AllClose(c, lazy_c);
510 });
511 }
512 }
513 }
514}
515
516TEST_F(LazyOpsTest, TestDivInPlace) {
517 for (torch::ScalarType scalar_type1 : {torch::kFloat}) {
518 torch::Tensor a = isFloatingType(scalar_type1)
519 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1))
520 : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type1));
521 for (torch::ScalarType scalar_type2 : {torch::kFloat}) {
522 torch::Tensor b = isFloatingType(scalar_type2)
523 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2))
524 : torch::randint(1, 100, {3, 4}, torch::TensorOptions(scalar_type2));
525 ForEachDevice([&](const torch::Device& device) {
526 torch::Tensor lazy_a = CopyToDevice(a, device);
527 torch::Tensor c = a.div_(b);
528 torch::Tensor lazy_b = CopyToDevice(b, device);
529 torch::Tensor lazy_c = lazy_a.div_(lazy_b);
530 ;
531 AllClose(c, lazy_c);
532 });
533 }
534 }
535}
536
537TEST_F(LazyOpsTest, TestDivInPlaceWithRoundingMode) {
538 c10::optional<c10::string_view> rounding_modes[] = {
539 "trunc", "floor", c10::nullopt};
540 for (const auto& rounding_mode : rounding_modes) {
541 for (torch::ScalarType scalar_type1 : {torch::kFloat}) {
542 torch::Tensor a = isFloatingType(scalar_type1)
543 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1))
544 : torch::randint(
545 -100, 100, {3, 4}, torch::TensorOptions(scalar_type1));
546 for (torch::ScalarType scalar_type2 : {torch::kFloat}) {
547 torch::Tensor b = isFloatingType(scalar_type2)
548 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2))
549 : torch::randint(
550 1, 100, {3, 4}, torch::TensorOptions(scalar_type2));
551 ForEachDevice([&](const torch::Device& device) {
552 torch::Tensor lazy_a = CopyToDevice(a, device);
553 torch::Tensor c = a.div_(b, rounding_mode);
554 torch::Tensor lazy_b = CopyToDevice(b, device);
555 torch::Tensor lazy_c = lazy_a.div_(lazy_b, rounding_mode);
556 AllClose(c, lazy_c);
557 });
558 }
559 }
560 }
561}
562
563TEST_F(LazyOpsTest, TestDivScalar) {
564 for (torch::ScalarType scalar_type :
565 {torch::kFloat,
566 torch::kByte,
567 torch::kChar,
568 torch::kShort,
569 torch::kInt,
570 torch::kLong}) {
571 torch::Tensor a = isFloatingType(scalar_type)
572 ? torch::rand(
573 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
574 : torch::randint(
575 1,
576 100,
577 {3, 4},
578 torch::TensorOptions(scalar_type).device(DefaultDevice()));
579 for (bool is_float : {true, false}) {
580 torch::Scalar b = is_float ? torch::Scalar(3.0) : torch::Scalar(3);
581 torch::Tensor c = torch::div(a, b);
582 ForEachDevice([&](const torch::Device& device) {
583 torch::Tensor lazy_a = CopyToDevice(a, device);
584 torch::Tensor lazy_c = torch::div(lazy_a, b);
585 AllClose(c, lazy_c);
586 });
587 }
588 }
589}
590
591TEST_F(LazyOpsTest, TestDivScalarInPlace) {
592 for (torch::ScalarType scalar_type : {torch::kFloat}) {
593 torch::Tensor a = isFloatingType(scalar_type)
594 ? torch::rand(
595 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
596 : torch::randint(
597 1,
598 100,
599 {3, 4},
600 torch::TensorOptions(scalar_type).device(DefaultDevice()));
601 for (bool is_float : {true, false}) {
602 torch::Scalar b = is_float ? torch::Scalar(3.0) : torch::Scalar(3);
603 ForEachDevice([&](const torch::Device& device) {
604 torch::Tensor lazy_a = CopyToDevice(a, device);
605 torch::Tensor c = a.div_(b);
606 torch::Tensor lazy_c = lazy_a.div_(b);
607 AllClose(c, lazy_c);
608 });
609 }
610 }
611}
612
613TEST_F(LazyOpsTest, TestDivOut) {
614 for (torch::ScalarType scalar_type : {torch::kFloat, torch::kDouble}) {
615 torch::Tensor a = torch::rand(
616 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
617 torch::Tensor b = torch::rand(
618 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
619 torch::Tensor c = torch::empty(
620 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
621 torch::div_out(c, a, b);
622 ForEachDevice([&](const torch::Device& device) {
623 torch::Tensor lazy_a = CopyToDevice(a, device);
624 torch::Tensor lazy_b = CopyToDevice(b, device);
625 torch::Tensor lazy_c = torch::empty({3, 4}, lazy_b.options());
626 torch::div_out(lazy_c, lazy_a, lazy_b);
627 AllClose(c, lazy_c);
628 });
629 }
630}
631
632TEST_F(LazyOpsTest, TestRsubScalar) {
633 torch::Tensor input = torch::rand(
634 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
635 torch::Scalar other(1.5);
636 torch::Scalar alpha(2.5);
637 torch::Tensor result = torch::rsub(input, other, alpha);
638 ForEachDevice([&](const torch::Device& device) {
639 torch::Tensor lazy_input = CopyToDevice(input, device);
640 torch::Tensor lazy_result = torch::rsub(lazy_input, other, alpha);
641 AllClose(result, lazy_result);
642 });
643}
644
645TEST_F(LazyOpsTest, TestNe) {
646 torch::Tensor a = torch::rand(
647 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
648 torch::Tensor b = torch::rand(
649 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
650 torch::Tensor c = torch::ne(a, b);
651 ForEachDevice([&](const torch::Device& device) {
652 torch::Tensor lazy_a = CopyToDevice(a, device);
653 torch::Tensor lazy_b = CopyToDevice(b, device);
654 torch::Tensor lazy_c = torch::ne(lazy_a, lazy_b);
655 AllEqual(c, lazy_c);
656 });
657}
658
659TEST_F(LazyOpsTest, TestNeInplace) {
660 torch::Tensor a = torch::rand(
661 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
662 torch::Tensor a_copy = a.clone();
663 torch::Tensor b = a.clone();
664 b[0] += 1;
665 a.ne_(b);
666 ForEachDevice([&](const torch::Device& device) {
667 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
668 torch::Tensor lazy_b = CopyToDevice(b, device);
669 lazy_a.ne_(lazy_b);
670 AllClose(a, lazy_a);
671 });
672}
673
674TEST_F(LazyOpsTest, TestEq) {
675 torch::Tensor a = torch::rand(
676 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
677 torch::Tensor b = a.clone();
678 torch::Tensor c = torch::eq(a, b);
679 ForEachDevice([&](const torch::Device& device) {
680 torch::Tensor lazy_a = CopyToDevice(a, device);
681 torch::Tensor lazy_b = CopyToDevice(b, device);
682 torch::Tensor lazy_c = torch::eq(lazy_a, lazy_b);
683 AllEqual(c, lazy_c);
684 });
685}
686
687TEST_F(LazyOpsTest, TestEqInplace) {
688 torch::Tensor a = torch::rand(
689 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
690 torch::Tensor b = a.clone();
691 b[0] += 1;
692 torch::Tensor a_copy = a.clone();
693 a.eq_(b);
694 ForEachDevice([&](const torch::Device& device) {
695 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
696 torch::Tensor lazy_b = CopyToDevice(b, device);
697 lazy_a.eq_(lazy_b);
698 AllClose(lazy_a, a);
699 });
700}
701
702TEST_F(LazyOpsTest, TestGe) {
703 torch::Tensor a = torch::rand(
704 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
705 torch::Tensor b = a.clone();
706 torch::Tensor c = torch::ge(a, b);
707 ForEachDevice([&](const torch::Device& device) {
708 torch::Tensor lazy_a = CopyToDevice(a, device);
709 torch::Tensor lazy_b = CopyToDevice(b, device);
710 torch::Tensor lazy_c = torch::ge(lazy_a, lazy_b);
711 AllEqual(c, lazy_c);
712 });
713}
714
715TEST_F(LazyOpsTest, TestGeInplace) {
716 torch::Tensor a = torch::rand(
717 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
718 torch::Tensor b = a.clone();
719 b[0] += 1;
720 b[1] -= 1;
721 torch::Tensor a_copy = a.clone();
722 a.ge_(b);
723 ForEachDevice([&](const torch::Device& device) {
724 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
725 torch::Tensor lazy_b = CopyToDevice(b, device);
726 lazy_a.ge_(lazy_b);
727 AllClose(lazy_a, a);
728 });
729}
730
731TEST_F(LazyOpsTest, TestLe) {
732 torch::Tensor a = torch::rand(
733 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
734 torch::Tensor b = a.clone();
735 torch::Tensor c = torch::le(a, b);
736 ForEachDevice([&](const torch::Device& device) {
737 torch::Tensor lazy_a = CopyToDevice(a, device);
738 torch::Tensor lazy_b = CopyToDevice(b, device);
739 torch::Tensor lazy_c = torch::le(lazy_a, lazy_b);
740 AllEqual(c, lazy_c);
741 });
742}
743
744TEST_F(LazyOpsTest, TestLeInplace) {
745 torch::Tensor a = torch::rand(
746 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
747 torch::Tensor b = a.clone();
748 b[0] += 1;
749 b[1] -= 1;
750 torch::Tensor a_copy = a.clone();
751 a.le_(b);
752 ForEachDevice([&](const torch::Device& device) {
753 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
754 torch::Tensor lazy_b = CopyToDevice(b, device);
755 lazy_a.le_(lazy_b);
756 AllClose(lazy_a, a);
757 });
758}
759
760TEST_F(LazyOpsTest, TestGt) {
761 torch::Tensor a = torch::rand(
762 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
763 torch::Tensor b = torch::add(a.clone(), torch::ones_like(a));
764 torch::Tensor c = torch::gt(b, a);
765 ForEachDevice([&](const torch::Device& device) {
766 torch::Tensor lazy_a = CopyToDevice(a, device);
767 torch::Tensor lazy_b = CopyToDevice(b, device);
768 torch::Tensor lazy_c = torch::gt(lazy_b, lazy_a);
769 AllEqual(c, lazy_c);
770 });
771}
772
773TEST_F(LazyOpsTest, TestGtInplace) {
774 torch::Tensor a = torch::rand(
775 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
776 torch::Tensor b = a.clone();
777 b[0] += 1;
778 b[1] -= 1;
779 torch::Tensor a_copy = a.clone();
780 a.gt_(b);
781 ForEachDevice([&](const torch::Device& device) {
782 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
783 torch::Tensor lazy_b = CopyToDevice(b, device);
784 lazy_a.gt_(lazy_b);
785 AllClose(lazy_a, a);
786 });
787}
788
789TEST_F(LazyOpsTest, TestLt) {
790 torch::Tensor a = torch::rand(
791 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
792 torch::Tensor b = torch::add(a.clone(), torch::ones_like(a));
793 torch::Tensor c = torch::lt(a, b);
794 ForEachDevice([&](const torch::Device& device) {
795 torch::Tensor lazy_a = CopyToDevice(a, device);
796 torch::Tensor lazy_b = CopyToDevice(b, device);
797 torch::Tensor lazy_c = torch::lt(lazy_a, lazy_b);
798 AllEqual(c, lazy_c);
799 });
800}
801
802TEST_F(LazyOpsTest, TestLtInplace) {
803 torch::Tensor a = torch::rand(
804 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
805 torch::Tensor b = a.clone();
806 b[0] += 1;
807 b[1] -= 1;
808 torch::Tensor a_copy = a.clone();
809 a.lt_(b);
810 ForEachDevice([&](const torch::Device& device) {
811 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
812 torch::Tensor lazy_b = CopyToDevice(b, device);
813 lazy_a.lt_(lazy_b);
814 AllClose(lazy_a, a);
815 });
816}
817
818TEST_F(LazyOpsTest, TestNeScalar) {
819 torch::Tensor input = torch::ones({2, 3});
820 torch::Scalar other(float(0));
821 torch::Tensor result = torch::ne(input, other);
822 ForEachDevice([&](const torch::Device& device) {
823 torch::Tensor lazy_input = CopyToDevice(input, device);
824 torch::Tensor lazy_result = torch::ne(lazy_input, other);
825 AllEqual(result, lazy_result);
826 });
827}
828
829TEST_F(LazyOpsTest, TestEqScalar) {
830 torch::Tensor input = torch::ones({2, 3});
831 torch::Scalar other(float(1));
832 torch::Tensor result = torch::eq(input, other);
833 ForEachDevice([&](const torch::Device& device) {
834 torch::Tensor lazy_input = CopyToDevice(input, device);
835 torch::Tensor lazy_result = torch::eq(lazy_input, other);
836 AllEqual(result, lazy_result);
837 });
838}
839
840TEST_F(LazyOpsTest, TestGeScalar) {
841 torch::Tensor input = torch::ones({2, 3});
842 torch::Scalar other(float(1));
843 torch::Tensor result = torch::ge(input, other);
844 ForEachDevice([&](const torch::Device& device) {
845 torch::Tensor lazy_input = CopyToDevice(input, device);
846 torch::Tensor lazy_result = torch::ge(lazy_input, other);
847 AllEqual(result, lazy_result);
848 });
849}
850
851TEST_F(LazyOpsTest, TestGeScalarInplace) {
852 torch::Tensor input = torch::arange(
853 -1.,
854 1.5,
855 0.5,
856 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
857 torch::Scalar other(float(0));
858 torch::Tensor input_copy = input.clone();
859 input.ge_(other);
860 ForEachDevice([&](const torch::Device& device) {
861 torch::Tensor lazy_input = CopyToDevice(input_copy, device);
862 lazy_input.ge_(other);
863 AllClose(lazy_input, input);
864 });
865}
866
867TEST_F(LazyOpsTest, TestLeScalar) {
868 torch::Tensor input = torch::ones({2, 3});
869 torch::Scalar other(float(1));
870 torch::Tensor result = torch::le(input, other);
871 ForEachDevice([&](const torch::Device& device) {
872 torch::Tensor lazy_input = CopyToDevice(input, device);
873 torch::Tensor lazy_result = torch::le(lazy_input, other);
874 AllEqual(result, lazy_result);
875 });
876}
877
878TEST_F(LazyOpsTest, TestLeScalarInplace) {
879 torch::Tensor input = torch::arange(
880 -1.,
881 1.5,
882 0.5,
883 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
884 torch::Scalar other(float(0));
885 torch::Tensor input_copy = input.clone();
886 input.le_(other);
887 ForEachDevice([&](const torch::Device& device) {
888 torch::Tensor lazy_input = CopyToDevice(input_copy, device);
889 lazy_input.le_(other);
890 AllClose(lazy_input, input);
891 });
892}
893
894TEST_F(LazyOpsTest, TestGtScalar) {
895 torch::Tensor input = torch::ones({2, 3});
896 torch::Scalar other(float(0.5));
897 torch::Tensor result = torch::gt(input, other);
898 ForEachDevice([&](const torch::Device& device) {
899 torch::Tensor lazy_input = CopyToDevice(input, device);
900 torch::Tensor lazy_result = torch::gt(lazy_input, other);
901 AllEqual(result, lazy_result);
902 });
903}
904
905TEST_F(LazyOpsTest, TestGtScalarInplace) {
906 torch::Tensor input = torch::arange(
907 -1.,
908 1.5,
909 0.5,
910 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
911 torch::Scalar other(float(0));
912 torch::Tensor input_copy = input.clone();
913 input.gt_(other);
914 ForEachDevice([&](const torch::Device& device) {
915 torch::Tensor lazy_input = CopyToDevice(input_copy, device);
916 lazy_input.gt_(other);
917 AllClose(lazy_input, input);
918 });
919}
920
921TEST_F(LazyOpsTest, TestLtScalar) {
922 torch::Tensor input = torch::ones({2, 3});
923 torch::Scalar other(float(1.5));
924 torch::Tensor result = torch::lt(input, other);
925 ForEachDevice([&](const torch::Device& device) {
926 torch::Tensor lazy_input = CopyToDevice(input, device);
927 torch::Tensor lazy_result = torch::lt(lazy_input, other);
928 AllEqual(result, lazy_result);
929 });
930}
931
932TEST_F(LazyOpsTest, TestLtScalarInplace) {
933 torch::Tensor input = torch::arange(
934 -1.,
935 1.5,
936 0.5,
937 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
938 torch::Scalar other(float(0));
939 torch::Tensor input_copy = input.clone();
940 input.lt_(other);
941 ForEachDevice([&](const torch::Device& device) {
942 torch::Tensor lazy_input = CopyToDevice(input_copy, device);
943 lazy_input.lt_(other);
944 AllClose(lazy_input, input);
945 });
946}
947
948TEST_F(LazyOpsTest, TestIntegerAdd) {
949 std::vector<torch::ScalarType> types(
950 {torch::kByte, torch::kChar, torch::kShort, torch::kInt, torch::kLong});
951
952 ForEachDevice([&](const torch::Device& device) {
953 for (auto type : types) {
954 torch::Tensor a =
955 torch::randint(0, 63, {2, 2}, torch::TensorOptions(type));
956 torch::Tensor b =
957 torch::randint(0, 63, {2, 2}, torch::TensorOptions(type));
958 torch::Scalar one =
959 isIntegralType(type, false) ? torch::Scalar(1) : torch::Scalar(1.0);
960 torch::Tensor c = torch::add(b, one);
961
962 torch::Tensor lazy_a = CopyToDevice(a, device);
963 torch::Tensor lazy_b = CopyToDevice(b, device);
964 torch::Tensor lazy_c = torch::add(lazy_b, one);
965
966 AllEqual(c, lazy_c);
967 }
968 });
969}
970
971TEST_F(LazyOpsTest, TestSVD) {
972 static const int dims[] = {4, 7};
973 for (auto m : dims) {
974 for (auto n : dims) {
975 torch::Tensor a = torch::rand(
976 {m, n}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
977 auto b = torch::svd(a, /*some=*/true, /*compute_uv=*/true);
978 ForEachDevice([&](const torch::Device& device) {
979 torch::Tensor lazy_a = CopyToDevice(a, device);
980 auto lazy_b = torch::svd(lazy_a, /*some=*/true, /*compute_uv=*/true);
981 // The U and V matrices might have different sign for column vectors, so
982 // cannot be compared if not by absolute value.
983 AllClose(
984 std::get<0>(b).abs(),
985 std::get<0>(lazy_b).abs(),
986 /*rtol=*/1e-3,
987 /*atol=*/1e-4);
988 torch::Tensor diag = std::get<1>(b);
989 torch::Tensor lazy_diag = std::get<1>(lazy_b);
990 ASSERT_EQ(diag.sizes(), lazy_diag.sizes());
991 AllClose(
992 diag,
993 lazy_diag,
994 /*rtol=*/1e-3,
995 /*atol=*/1e-4);
996 AllClose(
997 std::get<2>(b).abs(),
998 std::get<2>(lazy_b).abs(),
999 /*rtol=*/1e-3,
1000 /*atol=*/1e-4);
1001 });
1002 }
1003 }
1004}
1005
1006TEST_F(LazyOpsTest, TestQR) {
1007 static const int dims[] = {4, 7};
1008 for (auto m : dims) {
1009 for (auto n : dims) {
1010 torch::Tensor a = torch::rand(
1011 {m, n}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1012 auto b = torch::qr(a);
1013 ForEachDevice([&](const torch::Device& device) {
1014 torch::Tensor lazy_a = CopyToDevice(a, device);
1015 auto lazy_b = torch::qr(lazy_a);
1016 AllClose(
1017 std::get<0>(b).abs(),
1018 std::get<0>(lazy_b).abs(),
1019 /*rtol=*/1e-3,
1020 /*atol=*/1e-4);
1021 AllClose(
1022 std::get<1>(b).abs(),
1023 std::get<1>(lazy_b).abs(),
1024 /*rtol=*/1e-3,
1025 /*atol=*/1e-4);
1026 });
1027 }
1028 }
1029}
1030
1031TEST_F(LazyOpsTest, TestCholesky) {
1032 static const int dims[] = {4, 7};
1033 for (auto m : dims) {
1034 for (bool upper : {true, false}) {
1035 torch::Tensor a = torch::rand(
1036 {3, m, m},
1037 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1038 torch::Tensor pd_a =
1039 torch::matmul(a, torch::transpose(a, 1, 2)) +
1040 torch::eye(
1041 m, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1042 auto b = torch::cholesky(pd_a, upper);
1043 ForEachDevice([&](const torch::Device& device) {
1044 torch::Tensor lazy_a = CopyToDevice(pd_a, device);
1045 auto lazy_b = torch::cholesky(lazy_a, upper);
1046 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-4);
1047 });
1048 }
1049 }
1050}
1051
1052TEST_F(LazyOpsTest, TestLogDet) {
1053 static const int dims[] = {4, 7};
1054 for (auto m : dims) {
1055 torch::Tensor a = torch::rand(
1056 {3, m, m}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1057 torch::Tensor pd_a = torch::matmul(a, torch::transpose(a, 1, 2)) +
1058 torch::eye(m,
1059 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1060 torch::Tensor b = torch::logdet(pd_a);
1061 ForEachDevice([&](const torch::Device& device) {
1062 torch::Tensor lazy_a = CopyToDevice(pd_a, device);
1063 torch::Tensor lazy_b = torch::logdet(lazy_a);
1064 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-4);
1065 });
1066 }
1067}
1068
1069TEST_F(LazyOpsTest, TestTriangularSolve) {
1070 static const int dims[] = {4, 7};
1071 for (bool batched_a : {true, false}) {
1072 for (bool batched_b : {true, false}) {
1073 for (auto m : dims) {
1074 for (auto n : dims) {
1075 for (bool upper : {true, false}) {
1076 for (bool transpose : {true, false}) {
1077 for (bool unitriangular : {true, false}) {
1078 torch::Tensor a = torch::randn(
1079 {m, m},
1080 torch::TensorOptions(torch::kFloat)
1081 .device(DefaultDevice()));
1082 torch::Tensor b = torch::randn(
1083 {m, n},
1084 torch::TensorOptions(torch::kFloat)
1085 .device(DefaultDevice()));
1086 a = batched_a ? a.expand({3, m, m}).clone() : a;
1087 b = batched_b ? b.expand({3, m, n}).clone() : b;
1088 auto result = torch::triangular_solve(
1089 b,
1090 a,
1091 /*upper=*/upper,
1092 /*transpose=*/transpose,
1093 /*unitriangular=*/unitriangular);
1094 ForEachDevice([&](const torch::Device& device) {
1095 torch::Tensor lazy_a = CopyToDevice(a, device);
1096 torch::Tensor lazy_b = CopyToDevice(b, device);
1097 auto lazy_result = torch::triangular_solve(
1098 lazy_b,
1099 lazy_a,
1100 /*upper=*/upper,
1101 /*transpose=*/transpose,
1102 /*unitriangular=*/unitriangular);
1103 AllClose(
1104 std::get<0>(result),
1105 std::get<0>(lazy_result),
1106 /*rtol=*/1e-3,
1107 /*atol=*/1e-4);
1108 AllClose(
1109 std::get<1>(result),
1110 std::get<1>(lazy_result),
1111 /*rtol=*/1e-3,
1112 /*atol=*/1e-4);
1113 });
1114 }
1115 }
1116 }
1117 }
1118 }
1119 }
1120 }
1121}
1122
1123TEST_F(LazyOpsTest, TestKthValue) {
1124 torch::Tensor a = torch::rand(
1125 {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1126 for (int k = 1; k <= 3; ++k) {
1127 int rank = a.dim();
1128 for (int dim = -rank; dim < rank; ++dim) {
1129 for (bool keepdim : {false, true}) {
1130 auto b = torch::kthvalue(a, k, dim, keepdim);
1131 ForEachDevice([&](const torch::Device& device) {
1132 torch::Tensor lazy_a = CopyToDevice(a, device);
1133 auto lazy_b = torch::kthvalue(lazy_a, k, dim, keepdim);
1134 AllClose(std::get<0>(b), std::get<0>(lazy_b));
1135 AllEqual(std::get<1>(b), std::get<1>(lazy_b));
1136 });
1137 }
1138 }
1139 }
1140}
1141
1142TEST_F(LazyOpsTest, TestTopK) {
1143 torch::Tensor a = torch::rand(
1144 {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1145 for (int k = 1; k <= 3; ++k) {
1146 int rank = a.dim();
1147 for (int dim = -rank; dim < rank; ++dim) {
1148 for (bool largest : {false, true}) {
1149 auto b = torch::topk(a, k, dim, largest, /*sorted=*/true);
1150 ForEachDevice([&](const torch::Device& device) {
1151 torch::Tensor lazy_a = CopyToDevice(a, device);
1152 auto lazy_b = torch::topk(lazy_a, k, dim, largest, /*sorted=*/true);
1153 AllClose(std::get<0>(b), std::get<0>(lazy_b));
1154 AllEqual(std::get<1>(b), std::get<1>(lazy_b));
1155 });
1156 }
1157 }
1158 }
1159}
1160
1161TEST_F(LazyOpsTest, TestSort) {
1162 torch::Tensor a = torch::rand(
1163 {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1164 for (int k = 1; k <= 3; ++k) {
1165 for (int dim = 0; dim < 3; ++dim) {
1166 for (bool descending : {false, true}) {
1167 auto b = torch::sort(a, dim, descending);
1168 ForEachDevice([&](const torch::Device& device) {
1169 torch::Tensor lazy_a = CopyToDevice(a, device);
1170 auto lazy_b = torch::sort(lazy_a, dim, descending);
1171 AllClose(std::get<0>(b), std::get<0>(lazy_b));
1172 AllEqual(std::get<1>(b), std::get<1>(lazy_b));
1173 });
1174 }
1175 }
1176 }
1177}
1178
1179TEST_F(LazyOpsTest, TestSortDescWithMinValue) {
1180 std::vector<int8_t> values{-128, 100};
1181 torch::Tensor input =
1182 torch::tensor(values, torch::TensorOptions(torch::kChar));
1183 auto output = torch::sort(input, /*dim=*/0, /*descending=*/true);
1184 ForEachDevice([&](const torch::Device& device) {
1185 torch::Tensor lazy_input = CopyToDevice(input, device);
1186 auto lazy_output = torch::sort(lazy_input, /*dim=*/0, /*descending=*/true);
1187 AllEqual(std::get<0>(output), std::get<0>(lazy_output));
1188 AllEqual(std::get<1>(output), std::get<1>(lazy_output));
1189 });
1190}
1191
1192TEST_F(LazyOpsTest, TestArgSort) {
1193 torch::Tensor a = torch::rand(
1194 {4, 5, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1195 for (int k = 1; k <= 3; ++k) {
1196 for (int dim = 0; dim < 3; ++dim) {
1197 for (bool descending : {false, true}) {
1198 torch::Tensor b = torch::argsort(a, dim, descending);
1199 ForEachDevice([&](const torch::Device& device) {
1200 torch::Tensor lazy_a = CopyToDevice(a, device);
1201 torch::Tensor lazy_b = torch::argsort(lazy_a, dim, descending);
1202 AllEqual(b, lazy_b);
1203 });
1204 }
1205 }
1206 }
1207}
1208
1209TEST_F(LazyOpsTest, TestMin) {
1210 torch::Tensor a = torch::rand(
1211 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1212 torch::Tensor b = torch::rand(
1213 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1214 torch::Tensor c = torch::min(a, b);
1215 ForEachDevice([&](const torch::Device& device) {
1216 torch::Tensor lazy_a = CopyToDevice(a, device);
1217 torch::Tensor lazy_b = CopyToDevice(b, device);
1218 torch::Tensor lazy_c = torch::min(lazy_a, lazy_b);
1219 AllClose(c, lazy_c);
1220 });
1221}
1222
1223TEST_F(LazyOpsTest, TestMax) {
1224 torch::Tensor a = torch::rand(
1225 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1226 torch::Tensor b = torch::rand(
1227 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1228 torch::Tensor c = torch::max(a, b);
1229 ForEachDevice([&](const torch::Device& device) {
1230 torch::Tensor lazy_a = CopyToDevice(a, device);
1231 torch::Tensor lazy_b = CopyToDevice(b, device);
1232 torch::Tensor lazy_c = torch::max(lazy_a, lazy_b);
1233 AllClose(c, lazy_c);
1234 });
1235}
1236
1237TEST_F(LazyOpsTest, TestUnaryMin) {
1238 torch::Tensor input = torch::rand(
1239 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1240 torch::Tensor output = torch::min(input);
1241 ForEachDevice([&](const torch::Device& device) {
1242 torch::Tensor lazy_input = CopyToDevice(input, device);
1243 torch::Tensor lazy_output = torch::min(lazy_input);
1244 AllClose(output, lazy_output);
1245 });
1246}
1247
1248TEST_F(LazyOpsTest, TestUnaryMax) {
1249 torch::Tensor input = torch::rand(
1250 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1251 torch::Tensor output = torch::max(input);
1252 ForEachDevice([&](const torch::Device& device) {
1253 torch::Tensor lazy_input = CopyToDevice(input, device);
1254 torch::Tensor lazy_output = torch::max(lazy_input);
1255 AllClose(output, lazy_output);
1256 });
1257}
1258
1259TEST_F(LazyOpsTest, TestAll) {
1260 for (torch::ScalarType scalar_type :
1261 {torch::kFloat,
1262 torch::kByte,
1263 torch::kChar,
1264 torch::kShort,
1265 torch::kInt,
1266 torch::kLong}) {
1267 torch::Tensor a = isFloatingType(scalar_type)
1268 ? torch::rand(
1269 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
1270 : torch::randint(
1271 100,
1272 {3, 4},
1273 torch::TensorOptions(scalar_type).device(DefaultDevice()));
1274 torch::Tensor b = torch::all(a);
1275 ForEachDevice([&](const torch::Device& device) {
1276 torch::Tensor lazy_a = CopyToDevice(a, device);
1277 torch::Tensor lazy_b = torch::all(lazy_a);
1278 EqualValues(b, lazy_b);
1279 });
1280 }
1281}
1282
1283TEST_F(LazyOpsTest, TestAllDim) {
1284 torch::Tensor a = torch::randint(
1285 0,
1286 5,
1287 {2, 3, 4},
1288 torch::TensorOptions(torch::kByte).device(DefaultDevice()));
1289 int rank = a.dim();
1290 for (int dim = -rank; dim < rank; ++dim) {
1291 torch::Tensor b = torch::all(a, dim, /*keepdim=*/false);
1292 ForEachDevice([&](const torch::Device& device) {
1293 torch::Tensor lazy_a = CopyToDevice(a, device);
1294 torch::Tensor lazy_b = torch::all(lazy_a, dim, /*keepdim=*/false);
1295 EqualValues(b, lazy_b);
1296 });
1297 }
1298}
1299
1300TEST_F(LazyOpsTest, TestAllDimKeep) {
1301 torch::Tensor a = torch::randint(
1302 0,
1303 5,
1304 {2, 3, 4},
1305 torch::TensorOptions(torch::kByte).device(DefaultDevice()));
1306 int rank = a.dim();
1307 for (int dim = -rank; dim < rank; ++dim) {
1308 torch::Tensor b = torch::all(a, dim, /*keepdim=*/true);
1309 ForEachDevice([&](const torch::Device& device) {
1310 torch::Tensor lazy_a = CopyToDevice(a, device);
1311 torch::Tensor lazy_b = torch::all(lazy_a, dim, /*keepdim=*/true);
1312 EqualValues(b, lazy_b);
1313 });
1314 }
1315}
1316
1317TEST_F(LazyOpsTest, TestAmax) {
1318 torch::Tensor input = torch::rand(
1319 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1320 int rank = input.dim();
1321 for (bool keepdim : {false, true}) {
1322 for (int dim = -rank; dim < rank; ++dim) {
1323 torch::Tensor values = torch::amax(input, {dim}, /*keepdim=*/keepdim);
1324 ForEachDevice([&](const torch::Device& device) {
1325 torch::Tensor lazy_input = CopyToDevice(input, device);
1326 torch::Tensor lazy_values =
1327 torch::amax(lazy_input, {dim}, /*keepdim=*/keepdim);
1328 AllClose(values, lazy_values);
1329 });
1330 }
1331 for (int dim1 = -rank; dim1 < rank; ++dim1) {
1332 for (int dim2 = -rank; dim2 < rank; ++dim2) {
1333 if ((dim1 == dim2) || (dim1 == rank + dim2) || (dim2 == rank + dim1))
1334 continue;
1335 torch::Tensor values =
1336 torch::amax(input, {dim1, dim2}, /*keepdim=*/keepdim);
1337 ForEachDevice([&](const torch::Device& device) {
1338 torch::Tensor lazy_input = CopyToDevice(input, device);
1339 torch::Tensor lazy_values =
1340 torch::amax(lazy_input, {dim1, dim2}, /*keepdim=*/keepdim);
1341 AllClose(values, lazy_values);
1342 });
1343 }
1344 }
1345 }
1346 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
1347 ExpectCounterChanged("xla::amax", GetIgnoredCounters());
1348}
1349
1350TEST_F(LazyOpsTest, TestAmin) {
1351 torch::Tensor input = torch::rand(
1352 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1353 int rank = input.dim();
1354 for (bool keepdim : {false, true}) {
1355 for (int dim = -rank; dim < rank; ++dim) {
1356 torch::Tensor values = torch::amin(input, {dim}, /*keepdim=*/keepdim);
1357 ForEachDevice([&](const torch::Device& device) {
1358 torch::Tensor lazy_input = CopyToDevice(input, device);
1359 torch::Tensor lazy_values =
1360 torch::amin(lazy_input, {dim}, /*keepdim=*/keepdim);
1361 AllClose(values, lazy_values);
1362 });
1363 }
1364 for (int dim1 = -rank; dim1 < rank; ++dim1) {
1365 for (int dim2 = -rank; dim2 < rank; ++dim2) {
1366 if ((dim1 == dim2) || (dim1 == rank + dim2) || (dim2 == rank + dim1))
1367 continue;
1368 torch::Tensor values =
1369 torch::amin(input, {dim1, dim2}, /*keepdim=*/keepdim);
1370 ForEachDevice([&](const torch::Device& device) {
1371 torch::Tensor lazy_input = CopyToDevice(input, device);
1372 torch::Tensor lazy_values =
1373 torch::amin(lazy_input, {dim1, dim2}, /*keepdim=*/keepdim);
1374 AllClose(values, lazy_values);
1375 });
1376 }
1377 }
1378 }
1379 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
1380 ExpectCounterChanged("xla::amin", GetIgnoredCounters());
1381}
1382
1383TEST_F(LazyOpsTest, TestAny) {
1384 for (torch::ScalarType scalar_type :
1385 {torch::kFloat,
1386 torch::kByte,
1387 torch::kChar,
1388 torch::kShort,
1389 torch::kInt,
1390 torch::kLong}) {
1391 torch::Tensor a = isFloatingType(scalar_type)
1392 ? torch::rand(
1393 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
1394 : torch::randint(
1395 100,
1396 {3, 4},
1397 torch::TensorOptions(scalar_type).device(DefaultDevice()));
1398 torch::Tensor b = torch::any(a);
1399 ForEachDevice([&](const torch::Device& device) {
1400 torch::Tensor lazy_a = CopyToDevice(a, device);
1401 torch::Tensor lazy_b = torch::any(lazy_a);
1402 EqualValues(b, lazy_b);
1403 });
1404 }
1405}
1406
1407TEST_F(LazyOpsTest, TestAnyDim) {
1408 torch::Tensor a = torch::randint(
1409 0,
1410 5,
1411 {2, 3, 4},
1412 torch::TensorOptions(torch::kByte).device(DefaultDevice()));
1413 int rank = a.dim();
1414 for (int dim = -rank; dim < rank; ++dim) {
1415 torch::Tensor b = torch::any(a, dim, /*keepdim=*/false);
1416 ForEachDevice([&](const torch::Device& device) {
1417 torch::Tensor lazy_a = CopyToDevice(a, device);
1418 torch::Tensor lazy_b = torch::any(lazy_a, dim, /*keepdim=*/false);
1419 EqualValues(b, lazy_b);
1420 });
1421 }
1422}
1423
1424TEST_F(LazyOpsTest, TestAnyDimKeep) {
1425 torch::Tensor a = torch::randint(
1426 0,
1427 5,
1428 {2, 3, 4},
1429 torch::TensorOptions(torch::kByte).device(DefaultDevice()));
1430 int rank = a.dim();
1431 for (int dim = -rank; dim < rank; ++dim) {
1432 torch::Tensor b = torch::any(a, dim, /*keepdim=*/true);
1433 ForEachDevice([&](const torch::Device& device) {
1434 torch::Tensor lazy_a = CopyToDevice(a, device);
1435 torch::Tensor lazy_b = torch::any(lazy_a, dim, /*keepdim=*/true);
1436 EqualValues(b, lazy_b);
1437 });
1438 }
1439}
1440
1441TEST_F(LazyOpsTest, TestMean) {
1442 torch::Tensor a = torch::rand(
1443 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1444 torch::Tensor b = torch::mean(a);
1445 ForEachDevice([&](const torch::Device& device) {
1446 torch::Tensor lazy_a = CopyToDevice(a, device);
1447 torch::Tensor lazy_b = torch::mean(lazy_a);
1448 ASSERT_EQ(b.sizes(), lazy_b.sizes());
1449 AllClose(b, lazy_b);
1450 });
1451}
1452
1453TEST_F(LazyOpsTest, TestMeanCast) {
1454 torch::Tensor a = torch::rand(
1455 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1456 torch::Tensor b = torch::mean(a, torch::kDouble);
1457 ForEachDevice([&](const torch::Device& device) {
1458 torch::Tensor lazy_a = CopyToDevice(a, device);
1459 torch::Tensor lazy_b = torch::mean(lazy_a, torch::kDouble);
1460 AllClose(b, lazy_b);
1461 });
1462}
1463
1464TEST_F(LazyOpsTest, TestMeanInDim) {
1465 torch::Tensor a = torch::rand(
1466 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1467 int rank = a.dim();
1468 for (int dim = -rank; dim < rank; ++dim) {
1469 torch::Tensor b = torch::mean(a, {dim});
1470 ForEachDevice([&](const torch::Device& device) {
1471 torch::Tensor lazy_a = CopyToDevice(a, device);
1472 torch::Tensor lazy_b = torch::mean(lazy_a, {dim});
1473 AllClose(b, lazy_b);
1474 });
1475 }
1476}
1477
1478TEST_F(LazyOpsTest, TestMeanInDims) {
1479 torch::Tensor a = torch::rand(
1480 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1481 for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1482 torch::Tensor b = torch::mean(a, dims);
1483 ForEachDevice([&](const torch::Device& device) {
1484 torch::Tensor lazy_a = CopyToDevice(a, device);
1485 torch::Tensor lazy_b = torch::mean(lazy_a, dims);
1486 AllClose(b, lazy_b);
1487 });
1488 }
1489}
1490
1491TEST_F(LazyOpsTest, TestMeanInDimsKeepCast) {
1492 torch::Tensor a = torch::rand(
1493 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1494 for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1495 torch::Tensor b = torch::mean(a, dims, true, torch::kDouble);
1496 ForEachDevice([&](const torch::Device& device) {
1497 torch::Tensor lazy_a = CopyToDevice(a, device);
1498 torch::Tensor lazy_b = torch::mean(lazy_a, dims, true, torch::kDouble);
1499 AllClose(b, lazy_b);
1500 });
1501 }
1502}
1503
1504TEST_F(LazyOpsTest, TestMeanInDimOut) {
1505 torch::Tensor a = torch::rand(
1506 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1507 int rank = a.dim();
1508 for (int dim = -rank; dim < rank; ++dim) {
1509 torch::Tensor b = torch::empty(
1510 {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1511 torch::mean_out(b, a, {dim});
1512 ForEachDevice([&](const torch::Device& device) {
1513 torch::Tensor lazy_a = CopyToDevice(a, device);
1514 torch::Tensor lazy_b = torch::empty({4, 4}, lazy_a.options());
1515 torch::mean_out(lazy_b, lazy_a, {dim});
1516 AllClose(b, lazy_b);
1517 });
1518 }
1519}
1520
1521TEST_F(LazyOpsTest, TestStd) {
1522 torch::Tensor a = torch::rand(
1523 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1524 for (auto unbiased : {true, false}) {
1525 torch::Tensor b = torch::std(a, unbiased);
1526 ForEachDevice([&](const torch::Device& device) {
1527 torch::Tensor lazy_a = CopyToDevice(a, device);
1528 torch::Tensor lazy_b = torch::std(lazy_a, unbiased);
1529 AllClose(b, lazy_b);
1530 });
1531 }
1532}
1533
1534TEST_F(LazyOpsTest, TestStdInDim) {
1535 torch::Tensor a = torch::rand(
1536 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1537 int rank = a.dim();
1538 for (auto unbiased : {true, false}) {
1539 for (auto keepdim : {true, false}) {
1540 for (int dim = -rank; dim < rank; ++dim) {
1541 torch::Tensor b = torch::std(a, {dim}, unbiased, keepdim);
1542 ForEachDevice([&](const torch::Device& device) {
1543 torch::Tensor lazy_a = CopyToDevice(a, device);
1544 torch::Tensor lazy_b = torch::std(lazy_a, {dim}, unbiased, keepdim);
1545 AllClose(b, lazy_b);
1546 });
1547 }
1548 }
1549 }
1550}
1551
1552TEST_F(LazyOpsTest, TestStdWithCorrection) {
1553 torch::Tensor a = torch::rand(
1554 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1555 // int rank = a.dim();
1556 c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
1557 for (const auto& correction : corrections) {
1558 for (auto keepdim : {true, false}) {
1559 for (const auto& dim :
1560 std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1561 torch::Tensor b = torch::std(a, dim, correction, keepdim);
1562 ForEachDevice([&](const torch::Device& device) {
1563 torch::Tensor lazy_a = CopyToDevice(a, device);
1564 torch::Tensor lazy_b = torch::std(lazy_a, dim, correction, keepdim);
1565 AllClose(b, lazy_b);
1566 });
1567 }
1568 }
1569 }
1570}
1571
1572TEST_F(LazyOpsTest, TestStdMeanWithCorrection) {
1573 torch::Tensor a = torch::rand(
1574 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1575 // int rank = a.dim();
1576 c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
1577 for (const auto& correction : corrections) {
1578 for (auto keepdim : {true, false}) {
1579 for (const auto& dim :
1580 std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1581 auto b = torch::std_mean(a, dim, correction, keepdim);
1582 ForEachDevice([&](const torch::Device& device) {
1583 torch::Tensor lazy_a = CopyToDevice(a, device);
1584 auto lazy_b = torch::std_mean(lazy_a, dim, correction, keepdim);
1585 AllClose(std::get<0>(b), std::get<0>(lazy_b));
1586 AllClose(std::get<1>(b), std::get<1>(lazy_b));
1587 });
1588 }
1589 }
1590 }
1591}
1592
1593TEST_F(LazyOpsTest, TestSum) {
1594 torch::Tensor a = torch::rand(
1595 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1596 torch::Tensor b = torch::sum(a);
1597 ForEachDevice([&](const torch::Device& device) {
1598 torch::Tensor lazy_a = CopyToDevice(a, device);
1599 torch::Tensor lazy_b = torch::sum(lazy_a);
1600 AllClose(b, lazy_b);
1601 });
1602}
1603
1604TEST_F(LazyOpsTest, TestSumCast) {
1605 torch::Tensor a = torch::rand(
1606 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1607 torch::Tensor b = torch::sum(a, torch::kDouble);
1608 ForEachDevice([&](const torch::Device& device) {
1609 torch::Tensor lazy_a = CopyToDevice(a, device);
1610 torch::Tensor lazy_b = torch::sum(lazy_a, torch::kDouble);
1611 AllClose(b, lazy_b);
1612 });
1613}
1614
1615TEST_F(LazyOpsTest, TestSumU8) {
1616 torch::Tensor a = torch::ones(
1617 {256}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
1618 torch::Tensor b = torch::sum(a);
1619 ForEachDevice([&](const torch::Device& device) {
1620 torch::Tensor lazy_a = CopyToDevice(a, device);
1621 torch::Tensor lazy_b = torch::sum(lazy_a);
1622 AllEqual(b, lazy_b);
1623 });
1624}
1625
1626TEST_F(LazyOpsTest, TestSumInDim) {
1627 torch::Tensor a = torch::rand(
1628 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1629 int rank = a.dim();
1630 for (int dim = -rank; dim < rank; ++dim) {
1631 torch::Tensor b = torch::sum(a, {dim});
1632 ForEachDevice([&](const torch::Device& device) {
1633 torch::Tensor lazy_a = CopyToDevice(a, device);
1634 torch::Tensor lazy_b = torch::sum(lazy_a, {dim});
1635 AllClose(b, lazy_b);
1636 });
1637 }
1638}
1639
1640TEST_F(LazyOpsTest, TestSumInDims) {
1641 torch::Tensor a = torch::rand(
1642 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1643 for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1644 torch::Tensor b = torch::sum(a, dims);
1645 ForEachDevice([&](const torch::Device& device) {
1646 torch::Tensor lazy_a = CopyToDevice(a, device);
1647 torch::Tensor lazy_b = torch::sum(lazy_a, dims);
1648 AllClose(b, lazy_b);
1649 });
1650 }
1651}
1652
1653TEST_F(LazyOpsTest, TestSumInDimsKeep) {
1654 torch::Tensor a = torch::rand(
1655 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1656 for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1657 torch::Tensor b = torch::sum(a, dims, /*keepdim=*/true);
1658 ForEachDevice([&](const torch::Device& device) {
1659 torch::Tensor lazy_a = CopyToDevice(a, device);
1660 torch::Tensor lazy_b = torch::sum(lazy_a, dims, /*keepdim=*/true);
1661 AllClose(b, lazy_b);
1662 });
1663 }
1664}
1665
1666TEST_F(LazyOpsTest, TestSumInDimsKeepCast) {
1667 torch::Tensor a = torch::rand(
1668 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1669 for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1670 torch::Tensor b = torch::sum(a, dims, /*keepdim=*/true, torch::kDouble);
1671 ForEachDevice([&](const torch::Device& device) {
1672 torch::Tensor lazy_a = CopyToDevice(a, device);
1673 torch::Tensor lazy_b =
1674 torch::sum(lazy_a, dims, /*keepdim=*/true, torch::kDouble);
1675 AllClose(b, lazy_b);
1676 });
1677 }
1678}
1679
1680TEST_F(LazyOpsTest, TestVar) {
1681 torch::Tensor a = torch::rand(
1682 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1683 for (bool unbiased : {true, false}) {
1684 torch::Tensor b = torch::var(a, unbiased);
1685 ForEachDevice([&](const torch::Device& device) {
1686 torch::Tensor lazy_a = CopyToDevice(a, device);
1687 torch::Tensor lazy_b = torch::var(lazy_a, unbiased);
1688 AllClose(b, lazy_b);
1689 });
1690 }
1691}
1692
1693TEST_F(LazyOpsTest, TestVarWithDim) {
1694 torch::Tensor a = torch::rand(
1695 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1696 for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1697 for (bool keepDim : {true, false}) {
1698 for (bool unbiased : {true, false}) {
1699 torch::Tensor b = torch::var(a, dims, unbiased, keepDim);
1700 ForEachDevice([&](const torch::Device& device) {
1701 torch::Tensor lazy_a = CopyToDevice(a, device);
1702 torch::Tensor lazy_b = torch::var(lazy_a, dims, unbiased, keepDim);
1703 AllClose(b, lazy_b);
1704 });
1705 }
1706 }
1707 }
1708}
1709
1710TEST_F(LazyOpsTest, TestVarWithCorrection) {
1711 torch::Tensor a = torch::rand(
1712 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1713 c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
1714 for (const auto& dim : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1715 for (bool keepDim : {true, false}) {
1716 for (const auto& correction : corrections) {
1717 torch::Tensor b = torch::var(a, dim, correction, keepDim);
1718 ForEachDevice([&](const torch::Device& device) {
1719 torch::Tensor lazy_a = CopyToDevice(a, device);
1720 torch::Tensor lazy_b = torch::var(lazy_a, dim, correction, keepDim);
1721 AllClose(b, lazy_b);
1722 });
1723 }
1724 }
1725 }
1726 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
1727 ExpectCounterChanged("lazy::var", GetIgnoredCounters());
1728}
1729
1730TEST_F(LazyOpsTest, TestVarMeanWithCorrection) {
1731 torch::Tensor a = torch::rand(
1732 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1733 c10::optional<int64_t> corrections[] = {1, 2, c10::nullopt};
1734 for (const auto& dim : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
1735 for (const auto& correction : corrections) {
1736 for (auto keepdim : {true, false}) {
1737 auto b = torch::var_mean(a, dim, correction, keepdim);
1738 ForEachDevice([&](const torch::Device& device) {
1739 torch::Tensor lazy_a = CopyToDevice(a, device);
1740 auto lazy_b = torch::var_mean(lazy_a, dim, correction, keepdim);
1741 AllClose(std::get<0>(b), std::get<0>(lazy_b));
1742 AllClose(std::get<1>(b), std::get<1>(lazy_b));
1743 });
1744 }
1745 }
1746 }
1747}
1748
1749TEST_F(LazyOpsTest, TestMaxInDim) {
1750 torch::Tensor input = torch::rand(
1751 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1752 int rank = input.dim();
1753 for (int dim = -rank; dim < rank; ++dim) {
1754 for (bool keepdim : {false, true}) {
1755 auto values_indices = torch::max(input, dim, /*keepdim=*/keepdim);
1756 ForEachDevice([&](const torch::Device& device) {
1757 torch::Tensor lazy_input = CopyToDevice(input, device);
1758 auto lazy_values_indices =
1759 torch::max(lazy_input, dim, /*keepdim=*/keepdim);
1760 AllClose(std::get<0>(values_indices), std::get<0>(lazy_values_indices));
1761 AllEqual(std::get<1>(values_indices), std::get<1>(lazy_values_indices));
1762 });
1763 }
1764 }
1765}
1766
1767TEST_F(LazyOpsTest, TestMinInDim) {
1768 torch::Tensor input = torch::rand(
1769 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1770 int rank = input.dim();
1771 for (int dim = -rank; dim < rank; ++dim) {
1772 for (bool keepdim : {false, true}) {
1773 auto values_indices = torch::min(input, dim, /*keepdim=*/keepdim);
1774 ForEachDevice([&](const torch::Device& device) {
1775 torch::Tensor lazy_input = CopyToDevice(input, device);
1776 auto lazy_values_indices =
1777 torch::min(lazy_input, dim, /*keepdim=*/keepdim);
1778 AllClose(std::get<0>(values_indices), std::get<0>(lazy_values_indices));
1779 AllEqual(std::get<1>(values_indices), std::get<1>(lazy_values_indices));
1780 });
1781 }
1782 }
1783}
1784
1785TEST_F(LazyOpsTest, TestNorm) {
1786 torch::Tensor a = torch::rand(
1787 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1788 torch::Tensor b = torch::norm(a);
1789 ForEachDevice([&](const torch::Device& device) {
1790 torch::Tensor lazy_a = CopyToDevice(a, device);
1791 torch::Tensor lazy_b = torch::norm(lazy_a);
1792 AllClose(b, lazy_b);
1793 });
1794}
1795
1796TEST_F(LazyOpsTest, TestNormInDim) {
1797 torch::Tensor a = torch::rand(
1798 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1799 for (int dim : {1, -2}) {
1800 torch::Tensor b = torch::norm(a, 2, {dim}, /*keepdim=*/false);
1801 ForEachDevice([&](const torch::Device& device) {
1802 torch::Tensor lazy_a = CopyToDevice(a, device);
1803 torch::Tensor lazy_b = torch::norm(lazy_a, 2, {dim}, /*keepdim=*/false);
1804 AllClose(b, lazy_b);
1805 });
1806 }
1807}
1808
1809TEST_F(LazyOpsTest, TestNormInDims) {
1810 torch::Tensor a = torch::rand(
1811 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1812 for (auto dims : std::vector<std::vector<int64_t>>{{1, 2}, {-2, -1}}) {
1813 torch::Tensor b = torch::norm(a, 2, dims, /*keepdim=*/false);
1814 ForEachDevice([&](const torch::Device& device) {
1815 torch::Tensor lazy_a = CopyToDevice(a, device);
1816 torch::Tensor lazy_b = torch::norm(lazy_a, 2, dims, /*keepdim=*/false);
1817 AllClose(b, lazy_b);
1818 });
1819 }
1820}
1821
1822TEST_F(LazyOpsTest, TestNormInDimsKeep) {
1823 torch::Tensor a = torch::rand(
1824 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1825 for (auto dims : std::vector<std::vector<int64_t>>{{1, 2}, {-2, -1}}) {
1826 torch::Tensor b = torch::norm(a, 2, dims, /*keepdim=*/true);
1827 ForEachDevice([&](const torch::Device& device) {
1828 torch::Tensor lazy_a = CopyToDevice(a, device);
1829 torch::Tensor lazy_b = torch::norm(lazy_a, 2, dims, /*keepdim=*/true);
1830 AllClose(b, lazy_b);
1831 });
1832 }
1833}
1834
1835TEST_F(LazyOpsTest, TestNormalTwoTensor) {
1836 at::Tensor mean = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
1837 at::Tensor std = at::ones({10, 10, 10}, at::dtype(at::kFloat));
1838 ForEachDevice([&](const torch::Device& device) {
1839 at::Tensor lazy_mean = CopyToDevice(mean, device);
1840 at::Tensor lazy_std = CopyToDevice(std, device);
1841 at::Tensor lazy_normal = at::normal(lazy_mean, lazy_std);
1842 double res_mean = lazy_normal.mean().item().toDouble();
1843 double res_std = lazy_normal.std().item().toDouble();
1844 EXPECT_GT(res_mean, -0.06);
1845 EXPECT_LT(res_mean, 0.06);
1846 EXPECT_GT(res_std, 0.94);
1847 EXPECT_LT(res_std, 1.06);
1848 });
1849}
1850
1851TEST_F(LazyOpsTest, TestNormalDoubleMean) {
1852 at::Tensor std = at::ones({10, 10, 10}, at::dtype(at::kFloat));
1853 ForEachDevice([&](const torch::Device& device) {
1854 at::Tensor lazy_std = CopyToDevice(std, device);
1855 at::Tensor lazy_normal = at::normal(0, lazy_std);
1856 double res_mean = lazy_normal.mean().item().toDouble();
1857 double res_std = lazy_normal.std().item().toDouble();
1858 EXPECT_GT(res_mean, -0.06);
1859 EXPECT_LT(res_mean, 0.06);
1860 EXPECT_GT(res_std, 0.94);
1861 EXPECT_LT(res_std, 1.06);
1862 });
1863}
1864
1865TEST_F(LazyOpsTest, TestNormalDoubleStd) {
1866 at::Tensor mean = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
1867 ForEachDevice([&](const torch::Device& device) {
1868 at::Tensor lazy_mean = CopyToDevice(mean, device);
1869 at::Tensor lazy_normal = at::normal(lazy_mean, 1);
1870 double res_mean = lazy_normal.mean().item().toDouble();
1871 double res_std = lazy_normal.std().item().toDouble();
1872 EXPECT_GT(res_mean, -0.06);
1873 EXPECT_LT(res_mean, 0.06);
1874 EXPECT_GT(res_std, 0.94);
1875 EXPECT_LT(res_std, 1.06);
1876 });
1877}
1878
1879TEST_F(LazyOpsTest, TestNormalInPlace) {
1880 at::Tensor a = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
1881 ForEachDevice([&](const torch::Device& device) {
1882 at::Tensor lazy_a = CopyToDevice(a, device);
1883 lazy_a.normal_(/*mean=*/0, /*std=*/1);
1884 double res_mean = lazy_a.mean().item().toDouble();
1885 double res_std = lazy_a.std().item().toDouble();
1886 EXPECT_GT(res_mean, -0.06);
1887 EXPECT_LT(res_mean, 0.06);
1888 EXPECT_GT(res_std, 0.94);
1889 EXPECT_LT(res_std, 1.06);
1890 });
1891}
1892
1893TEST_F(LazyOpsTest, TestUniformInPlace) {
1894 const double eps = 1e-3;
1895 at::Tensor a = at::zeros({10, 10, 10}, at::dtype(at::kFloat));
1896 ForEachDevice([&](const torch::Device& device) {
1897 at::Tensor lazy_a = CopyToDevice(a, device);
1898 lazy_a.uniform_(/*from=*/0, /*to=*/1);
1899 at::Tensor cpu_a = ToCpuTensor(lazy_a);
1900 double res_min = cpu_a.min().item().toDouble();
1901 double res_max = cpu_a.max().item().toDouble();
1902 EXPECT_GT(res_min, 0.0 - eps);
1903 EXPECT_LT(res_max, 1.0 + eps);
1904 });
1905}
1906
1907TEST_F(LazyOpsTest, TestRandomInPlace) {
1908 for (auto dtype :
1909 {torch::kFloat,
1910 torch::kDouble,
1911 torch::kByte,
1912 torch::kChar,
1913 torch::kShort,
1914 torch::kInt,
1915 torch::kLong}) {
1916 const double eps = 0.2;
1917 torch::Tensor a = torch::zeros({10, 10, 10}, torch::TensorOptions(dtype));
1918 ForEachDevice([&](const torch::Device& device) {
1919 torch::Tensor lazy_a = CopyToDevice(a, device);
1920 lazy_a.random_(/*from=*/0, /*to=*/10);
1921 double res_mean = lazy_a.sum().item().toDouble() / a.numel();
1922 double res_min = lazy_a.min().item().toDouble();
1923 double res_max = lazy_a.max().item().toDouble();
1924 EXPECT_GT(res_mean, 4.5 - eps);
1925 EXPECT_LT(res_mean, 4.5 + eps);
1926 EXPECT_EQ(res_min, 0.0);
1927 EXPECT_EQ(res_max, 9.0);
1928 });
1929 }
1930}
1931
1932TEST_F(LazyOpsTest, TestRandomInPlaceDefaultFrom) {
1933 for (auto dtype :
1934 {torch::kFloat,
1935 torch::kDouble,
1936 torch::kByte,
1937 torch::kChar,
1938 torch::kShort,
1939 torch::kInt,
1940 torch::kLong}) {
1941 const double eps = 0.2;
1942 torch::Tensor a = torch::zeros({10, 10, 10}, torch::TensorOptions(dtype));
1943 ForEachDevice([&](const torch::Device& device) {
1944 torch::Tensor lazy_a = CopyToDevice(a, device);
1945 lazy_a.random_(/*to=*/10);
1946 double res_mean = lazy_a.sum().item().toDouble() / a.numel();
1947 double res_min = lazy_a.min().item().toDouble();
1948 double res_max = lazy_a.max().item().toDouble();
1949 EXPECT_GT(res_mean, 4.5 - eps);
1950 EXPECT_LT(res_mean, 4.5 + eps);
1951 EXPECT_EQ(res_min, 0.0);
1952 EXPECT_EQ(res_max, 9.0);
1953 });
1954 }
1955}
1956
1957TEST_F(LazyOpsTest, TestRandomInPlaceDefault) {
1958 for (auto dtype :
1959 {torch::kFloat,
1960 torch::kDouble,
1961 torch::kByte,
1962 torch::kChar,
1963 torch::kShort,
1964 torch::kInt,
1965 torch::kLong}) {
1966 auto input = torch::zeros({10}, torch::TensorOptions(dtype));
1967 ForEachDevice([&](const torch::Device& device) {
1968 auto lazyInput = CopyToDevice(input, device);
1969 lazyInput.random_();
1970 auto output = ToCpuTensor(lazyInput);
1971 EXPECT_TRUE(torch::all(output.ne(input)).item<bool>());
1972 });
1973 }
1974}
1975
1976TEST_F(LazyOpsTest, TestNormGeneral) {
1977 torch::Tensor a = torch::randn(
1978 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1979 torch::Tensor b = torch::norm(a, 3.5);
1980 ForEachDevice([&](const torch::Device& device) {
1981 torch::Tensor lazy_a = CopyToDevice(a, device);
1982 torch::Tensor lazy_b = torch::norm(lazy_a, 3.5);
1983 AllClose(b, lazy_b);
1984 });
1985}
1986
1987TEST_F(LazyOpsTest, TestNormNuclear) {
1988 torch::Tensor a = torch::rand(
1989 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
1990 torch::Tensor b = torch::norm(a, 1);
1991 ForEachDevice([&](const torch::Device& device) {
1992 torch::Tensor lazy_a = CopyToDevice(a, device);
1993 torch::Tensor lazy_b = torch::norm(lazy_a, 1);
1994 AllClose(b, lazy_b);
1995 });
1996}
1997
1998TEST_F(LazyOpsTest, TestFrobeniusNormInDim) {
1999 torch::Tensor a = torch::rand(
2000 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2001 for (int dim : {1, -2}) {
2002 torch::Tensor b = torch::frobenius_norm(a, {dim}, /*keepdim=*/false);
2003 ForEachDevice([&](const torch::Device& device) {
2004 torch::Tensor lazy_a = CopyToDevice(a, device);
2005 torch::Tensor lazy_b =
2006 torch::frobenius_norm(lazy_a, {dim}, /*keepdim=*/false);
2007 AllClose(b, lazy_b);
2008 });
2009 }
2010}
2011
2012TEST_F(LazyOpsTest, TestFrobeniusNormInDims) {
2013 torch::Tensor a = torch::rand(
2014 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2015 for (auto dims : std::vector<std::vector<int64_t>>{{1, 2}, {-2, -1}}) {
2016 torch::Tensor b = torch::frobenius_norm(a, dims, /*keepdim=*/false);
2017 ForEachDevice([&](const torch::Device& device) {
2018 torch::Tensor lazy_a = CopyToDevice(a, device);
2019 torch::Tensor lazy_b =
2020 torch::frobenius_norm(lazy_a, dims, /*keepdim=*/false);
2021 AllClose(b, lazy_b);
2022 });
2023 }
2024}
2025
2026TEST_F(LazyOpsTest, TestGroupNorm) {
2027 int num_channels = 6;
2028 torch::Tensor input = torch::rand(
2029 {20, num_channels, 10, 10},
2030 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2031 torch::Tensor weight = torch::rand(
2032 {num_channels},
2033 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2034 torch::Tensor bias = torch::rand(
2035 {num_channels},
2036 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2037 double eps = 1e-05;
2038 for (int num_groups : {3, 6, 1}) {
2039 torch::Tensor output = torch::group_norm(
2040 input,
2041 num_groups,
2042 weight,
2043 bias,
2044 eps,
2045 /*cudnn_enabled=*/false);
2046 ForEachDevice([&](const torch::Device& device) {
2047 torch::Tensor lazy_input = CopyToDevice(input, device);
2048 torch::Tensor lazy_weight = CopyToDevice(weight, device);
2049 torch::Tensor lazy_bias = CopyToDevice(bias, device);
2050 torch::Tensor lazy_output = torch::group_norm(
2051 lazy_input,
2052 num_groups,
2053 lazy_weight,
2054 lazy_bias,
2055 eps,
2056 /*cudnn_enabled=*/false);
2057 AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5);
2058 });
2059 }
2060}
2061
2062TEST_F(LazyOpsTest, TestGroupNormBackward) {
2063 int num_channels = 6;
2064 torch::Tensor input = torch::rand(
2065 {2, num_channels, 5, 5},
2066 torch::TensorOptions(torch::kFloat)
2067 .device(DefaultDevice())
2068 .requires_grad(true));
2069 torch::Tensor weight = torch::rand(
2070 {num_channels},
2071 torch::TensorOptions(torch::kFloat)
2072 .device(DefaultDevice())
2073 .requires_grad(true));
2074 torch::Tensor bias = torch::rand(
2075 {num_channels},
2076 torch::TensorOptions(torch::kFloat)
2077 .device(DefaultDevice())
2078 .requires_grad(true));
2079 double eps = 1e-05;
2080 for (bool undef_weight : {true, false}) {
2081 for (int num_groups : {3, 6, 1}) {
2082 auto testfn =
2083 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
2084 return torch::group_norm(
2085 /*input=*/inputs[0],
2086 num_groups,
2087 inputs[1],
2088 inputs[2],
2089 /*eps=*/eps,
2090 /*cudnn_enabled=*/false);
2091 };
2092 torch::Tensor undef;
2093 ForEachDevice([&](const torch::Device& device) {
2094 TestBackward(
2095 {input, undef_weight ? undef : weight, undef_weight ? undef : bias},
2096 device,
2097 testfn,
2098 /*rtol=*/1e-3,
2099 /*atol=*/1e-3,
2100 /*derivative_level=*/2);
2101 });
2102 }
2103 }
2104}
2105
2106TEST_F(LazyOpsTest, TestInstanceNorm) {
2107 int batch = 5;
2108 int num_channels = 20;
2109 torch::Tensor input = torch::rand(
2110 {batch, num_channels, 10, 10},
2111 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2112 torch::Tensor weight = torch::rand(
2113 {num_channels},
2114 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2115 torch::Tensor bias = torch::rand(
2116 {num_channels},
2117 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2118 torch::Tensor running_mean = torch::zeros(
2119 {num_channels},
2120 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2121 torch::Tensor running_var = torch::ones(
2122 {num_channels},
2123 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2124 double momentum = 0.1;
2125 double eps = 1e-05;
2126 torch::Tensor output = torch::instance_norm(
2127 input,
2128 weight,
2129 bias,
2130 running_mean,
2131 running_var,
2132 /*use_input_stats=*/true,
2133 momentum,
2134 eps,
2135 /*cudnn_enabled=*/false);
2136 ForEachDevice([&](const torch::Device& device) {
2137 torch::Tensor lazy_input = CopyToDevice(input, device);
2138 torch::Tensor lazy_weight = CopyToDevice(weight, device);
2139 torch::Tensor lazy_bias = CopyToDevice(bias, device);
2140 torch::Tensor lazy_running_mean = CopyToDevice(running_mean, device);
2141 torch::Tensor lazy_running_var = CopyToDevice(running_var, device);
2142 torch::Tensor lazy_output = torch::instance_norm(
2143 lazy_input,
2144 lazy_weight,
2145 lazy_bias,
2146 lazy_running_mean,
2147 lazy_running_var,
2148 /*use_input_stats=*/true,
2149 momentum,
2150 eps,
2151 /*cudnn_enabled=*/false);
2152 AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5);
2153 });
2154}
2155
2156TEST_F(LazyOpsTest, TestLayerNorm) {
2157 torch::Tensor input = torch::rand(
2158 {20, 10, 10, 10},
2159 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2160 double eps = 1e-05;
2161 torch::Tensor undef;
2162 for (bool undef_weight : {true, false}) {
2163 for (int64_t normalized_size : {2, 3}) {
2164 std::vector<int64_t> normalized_shape(normalized_size, 10);
2165 torch::Tensor weight = torch::rand(
2166 normalized_shape,
2167 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2168 torch::Tensor bias = torch::rand(
2169 normalized_shape,
2170 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2171 torch::Tensor output = torch::layer_norm(
2172 input,
2173 normalized_shape,
2174 undef_weight ? undef : weight,
2175 undef_weight ? undef : bias,
2176 eps,
2177 /*cudnn_enabled=*/false);
2178 ForEachDevice([&](const torch::Device& device) {
2179 torch::Tensor lazy_input = CopyToDevice(input, device);
2180 torch::Tensor lazy_weight =
2181 undef_weight ? undef : CopyToDevice(weight, device);
2182 torch::Tensor lazy_bias =
2183 undef_weight ? undef : CopyToDevice(bias, device);
2184 torch::Tensor lazy_output = torch::layer_norm(
2185 lazy_input,
2186 normalized_shape,
2187 lazy_weight,
2188 lazy_bias,
2189 eps,
2190 /*cudnn_enabled=*/false);
2191 AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5);
2192 });
2193 }
2194 }
2195}
2196
2197TEST_F(LazyOpsTest, TestLayerNormBackward) {
2198 torch::Tensor input = torch::rand(
2199 {2, 3, 3, 3},
2200 torch::TensorOptions(torch::kFloat)
2201 .device(DefaultDevice())
2202 .requires_grad(true));
2203 double eps = 1e-05;
2204 for (bool undef_weight : {true, false}) {
2205 for (int64_t normalized_size : {2, 3}) {
2206 std::vector<int64_t> normalized_shape(normalized_size, 3);
2207 auto testfn =
2208 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
2209 return torch::layer_norm(
2210 /*input=*/inputs[0],
2211 normalized_shape,
2212 inputs[1],
2213 inputs[2],
2214 /*eps=*/eps,
2215 /*cudnn_enabled=*/false);
2216 };
2217 torch::Tensor weight = torch::rand(
2218 normalized_shape,
2219 torch::TensorOptions(torch::kFloat)
2220 .device(DefaultDevice())
2221 .requires_grad(true));
2222 torch::Tensor bias = torch::rand(
2223 normalized_shape,
2224 torch::TensorOptions(torch::kFloat)
2225 .device(DefaultDevice())
2226 .requires_grad(true));
2227 torch::Tensor undef;
2228 ForEachDevice([&](const torch::Device& device) {
2229 TestBackward(
2230 {input, undef_weight ? undef : weight, undef_weight ? undef : bias},
2231 device,
2232 testfn,
2233 /*rtol=*/1e-3,
2234 /*atol=*/1e-4,
2235 /*derivative_level=*/2);
2236 });
2237 }
2238 }
2239}
2240
2241TEST_F(LazyOpsTest, TestNuclearNorm) {
2242 torch::Tensor a = torch::rand(
2243 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2244 torch::Tensor b = torch::nuclear_norm(a);
2245 ForEachDevice([&](const torch::Device& device) {
2246 torch::Tensor lazy_a = CopyToDevice(a, device);
2247 torch::Tensor lazy_b = torch::nuclear_norm(lazy_a);
2248 AllClose(b, lazy_b);
2249 });
2250}
2251
2252TEST_F(LazyOpsTest, TestPairwiseDistance) {
2253 torch::Tensor x1 = torch::rand(
2254 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2255 torch::Tensor x2 = torch::rand(
2256 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2257 double eps = 1e-6;
2258 for (bool keepdim : {false, true}) {
2259 for (double p : {1, 2, 3, 4}) {
2260 ForEachDevice([&](const torch::Device& device) {
2261 torch::Tensor output =
2262 torch::pairwise_distance(x1, x2, p, eps, keepdim);
2263 torch::Tensor lazy_x1 = CopyToDevice(x1, device);
2264 torch::Tensor lazy_x2 = CopyToDevice(x2, device);
2265 torch::Tensor lazy_output =
2266 torch::pairwise_distance(lazy_x1, lazy_x2, p, eps, keepdim);
2267 AllClose(output, lazy_output, /*rtol=*/1e-5, /*atol=*/1e-5);
2268 });
2269 }
2270 }
2271}
2272
2273TEST_F(LazyOpsTest, TestCosineSimilarity) {
2274 torch::Tensor x1 = torch::rand(
2275 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2276 torch::Tensor x2 = torch::rand(
2277 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2278 double eps = 1e-8;
2279 int rank = x1.dim();
2280 for (int dim = -rank; dim < rank; ++dim) {
2281 ForEachDevice([&](const torch::Device& device) {
2282 torch::Tensor output = torch::cosine_similarity(x1, x2, dim, eps);
2283 torch::Tensor lazy_x1 = CopyToDevice(x1, device);
2284 torch::Tensor lazy_x2 = CopyToDevice(x2, device);
2285 torch::Tensor lazy_output =
2286 torch::cosine_similarity(lazy_x1, lazy_x2, dim, eps);
2287 AllClose(output, lazy_output);
2288 });
2289 }
2290}
2291
2292TEST_F(LazyOpsTest, TestCosineEmbeddingLoss) {
2293 torch::Tensor input1 = torch::rand(
2294 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2295 torch::Tensor input2 = torch::rand(
2296 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2297 torch::Tensor target = torch::rand(
2298 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2299 for (torch::Reduction::Reduction reduction :
2300 {torch::Reduction::Mean, torch::Reduction::Sum}) {
2301 for (double margin : {0., 0.2}) {
2302 ForEachDevice([&](const torch::Device& device) {
2303 torch::Tensor output = torch::cosine_embedding_loss(
2304 input1, input2, target, margin, reduction);
2305 torch::Tensor lazy_input1 = CopyToDevice(input1, device);
2306 torch::Tensor lazy_input2 = CopyToDevice(input2, device);
2307 torch::Tensor lazy_target = CopyToDevice(target, device);
2308 torch::Tensor lazy_output = torch::cosine_embedding_loss(
2309 lazy_input1, lazy_input2, lazy_target, margin, reduction);
2310 AllClose(output, lazy_output);
2311 });
2312 }
2313 }
2314}
2315
2316TEST_F(LazyOpsTest, TestHingeEmbeddingLoss) {
2317 torch::Tensor input = torch::rand(
2318 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2319 torch::Tensor target = torch::rand(
2320 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2321 for (torch::Reduction::Reduction reduction :
2322 {torch::Reduction::Mean, torch::Reduction::Sum}) {
2323 for (double margin : {0., 0.2}) {
2324 ForEachDevice([&](const torch::Device& device) {
2325 torch::Tensor output =
2326 torch::hinge_embedding_loss(input, target, margin, reduction);
2327 torch::Tensor lazy_input = CopyToDevice(input, device);
2328 torch::Tensor lazy_target = CopyToDevice(target, device);
2329 torch::Tensor lazy_output = torch::hinge_embedding_loss(
2330 lazy_input, lazy_target, margin, reduction);
2331 AllClose(output, lazy_output);
2332 });
2333 }
2334 }
2335}
2336
2337TEST_F(LazyOpsTest, TestTripletMarginLoss) {
2338 torch::Tensor anchor = torch::rand(
2339 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2340 torch::Tensor positive = torch::abs(torch::rand(
2341 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
2342 torch::Tensor negative = torch::neg(torch::abs(torch::rand(
2343 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()))));
2344 double eps = 1e-6;
2345 for (double margin : {0., 0.2}) {
2346 for (double p : {1, 2, 3, 4}) {
2347 for (bool swap : {false, true}) {
2348 for (torch::Reduction::Reduction reduction :
2349 {torch::Reduction::Mean, torch::Reduction::Sum}) {
2350 ForEachDevice([&](const torch::Device& device) {
2351 torch::Tensor output = torch::triplet_margin_loss(
2352 anchor, positive, negative, margin, p, eps, swap, reduction);
2353 torch::Tensor lazy_anchor = CopyToDevice(anchor, device);
2354 torch::Tensor lazy_positive = CopyToDevice(positive, device);
2355 torch::Tensor lazy_negative = CopyToDevice(negative, device);
2356 torch::Tensor lazy_output = torch::triplet_margin_loss(
2357 lazy_anchor,
2358 lazy_positive,
2359 lazy_negative,
2360 margin,
2361 p,
2362 eps,
2363 swap,
2364 reduction);
2365 AllClose(output, lazy_output);
2366 });
2367 }
2368 }
2369 }
2370 }
2371}
2372
2373TEST_F(LazyOpsTest, TestBinaryCrossEntropy) {
2374 int batch = 10;
2375 int classes = 5;
2376 torch::Tensor input = torch::rand(
2377 {batch, classes},
2378 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2379 torch::Tensor target = torch::rand(
2380 {batch, classes},
2381 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2382 torch::Tensor weight = torch::rand(
2383 {batch, classes},
2384 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2385 torch::Tensor undef;
2386 for (torch::Reduction::Reduction reduction :
2387 {torch::Reduction::Mean,
2388 torch::Reduction::Sum,
2389 torch::Reduction::None}) {
2390 for (bool undef_weight : {false, true}) {
2391 ForEachDevice([&](const torch::Device& device) {
2392 torch::Tensor output = torch::binary_cross_entropy(
2393 input, target, undef_weight ? undef : weight, reduction);
2394 torch::Tensor lazy_input = CopyToDevice(input, device);
2395 torch::Tensor lazy_target = CopyToDevice(target, device);
2396 torch::Tensor lazy_weight =
2397 undef_weight ? undef : CopyToDevice(weight, device);
2398 torch::Tensor lazy_output = torch::binary_cross_entropy(
2399 lazy_input, lazy_target, lazy_weight, reduction);
2400 AllClose(output, lazy_output, /*rtol=*/1e-4, /*atol=*/1e-5);
2401 });
2402 }
2403 }
2404}
2405
2406TEST_F(LazyOpsTest, TestMarginRankingLoss) {
2407 torch::Tensor input1 = torch::rand(
2408 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2409 torch::Tensor input2 = torch::rand(
2410 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2411 torch::Tensor target = torch::rand(
2412 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2413 for (torch::Reduction::Reduction reduction :
2414 {torch::Reduction::Mean, torch::Reduction::Sum}) {
2415 for (double margin : {0., 0.2}) {
2416 ForEachDevice([&](const torch::Device& device) {
2417 torch::Tensor output = torch::margin_ranking_loss(
2418 input1, input2, target, margin, reduction);
2419 torch::Tensor lazy_input1 = CopyToDevice(input1, device);
2420 torch::Tensor lazy_input2 = CopyToDevice(input2, device);
2421 torch::Tensor lazy_target = CopyToDevice(target, device);
2422 torch::Tensor lazy_output = torch::margin_ranking_loss(
2423 lazy_input1, lazy_input2, lazy_target, margin, reduction);
2424 AllClose(output, lazy_output);
2425 });
2426 }
2427 }
2428}
2429
2430TEST_F(LazyOpsTest, TestBCEWithLogits) {
2431 int batch = 10;
2432 int classes = 5;
2433 torch::Tensor input = torch::rand(
2434 {batch, classes},
2435 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2436 torch::Tensor target = torch::rand(
2437 {batch, classes},
2438 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2439 torch::Tensor weight = torch::rand(
2440 {classes}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2441 torch::Tensor pos_weight = torch::rand(
2442 {classes}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2443 torch::Tensor undef;
2444 for (torch::Reduction::Reduction reduction :
2445 {torch::Reduction::Mean, torch::Reduction::Sum}) {
2446 for (bool undef_weight : {false, true}) {
2447 for (bool undef_pos_weight : {false, true}) {
2448 ForEachDevice([&](const torch::Device& device) {
2449 torch::Tensor output = torch::binary_cross_entropy_with_logits(
2450 input,
2451 target,
2452 undef_weight ? undef : weight,
2453 undef_pos_weight ? undef : pos_weight,
2454 reduction);
2455 torch::Tensor lazy_input = CopyToDevice(input, device);
2456 torch::Tensor lazy_target = CopyToDevice(target, device);
2457 torch::Tensor lazy_weight =
2458 undef_weight ? undef : CopyToDevice(weight, device);
2459 torch::Tensor lazy_pos_weight =
2460 undef_pos_weight ? undef : CopyToDevice(pos_weight, device);
2461 torch::Tensor lazy_output = torch::binary_cross_entropy_with_logits(
2462 lazy_input, lazy_target, lazy_weight, lazy_pos_weight, reduction);
2463 });
2464 }
2465 }
2466 }
2467}
2468
2469TEST_F(LazyOpsTest, TestKlDiv) {
2470 torch::Tensor input = torch::rand(
2471 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2472 torch::Tensor target = torch::rand(
2473 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2474 for (bool log_target : {true, false}) {
2475 for (torch::Reduction::Reduction reduction :
2476 {torch::Reduction::Mean, torch::Reduction::Sum}) {
2477 ForEachDevice([&](const torch::Device& device) {
2478 torch::Tensor output =
2479 torch::kl_div(input, target, reduction, log_target);
2480 torch::Tensor lazy_input = CopyToDevice(input, device);
2481 torch::Tensor lazy_target = CopyToDevice(target, device);
2482 torch::Tensor lazy_output =
2483 torch::kl_div(lazy_input, lazy_target, reduction, log_target);
2484 AllClose(output, lazy_output);
2485 });
2486 }
2487 }
2488}
2489
2490TEST_F(LazyOpsTest, TestProd) {
2491 torch::Tensor a = torch::rand(
2492 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2493 torch::Tensor b = torch::prod(a);
2494 ForEachDevice([&](const torch::Device& device) {
2495 torch::Tensor lazy_a = CopyToDevice(a, device);
2496 torch::Tensor lazy_b = torch::prod(lazy_a);
2497 AllClose(b, lazy_b);
2498 });
2499}
2500
2501TEST_F(LazyOpsTest, TestProdCast) {
2502 torch::Tensor a = torch::rand(
2503 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2504 torch::Tensor b = torch::prod(a, torch::kDouble);
2505 ForEachDevice([&](const torch::Device& device) {
2506 torch::Tensor lazy_a = CopyToDevice(a, device);
2507 torch::Tensor lazy_b = torch::prod(lazy_a, torch::kDouble);
2508 AllClose(b, lazy_b);
2509 });
2510}
2511
2512TEST_F(LazyOpsTest, TestProdInDim) {
2513 torch::Tensor a = torch::rand(
2514 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2515 int rank = a.dim();
2516 for (int dim = -rank; dim < rank; ++dim) {
2517 torch::Tensor b = torch::prod(a, dim);
2518 ForEachDevice([&](const torch::Device& device) {
2519 torch::Tensor lazy_a = CopyToDevice(a, device);
2520 torch::Tensor lazy_b = torch::prod(lazy_a, dim);
2521 AllClose(b, lazy_b);
2522 });
2523 }
2524}
2525
2526TEST_F(LazyOpsTest, TestProdInDimKeepCast) {
2527 torch::Tensor a = torch::rand(
2528 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2529 int rank = a.dim();
2530 for (int dim = -rank; dim < rank; ++dim) {
2531 torch::Tensor b = torch::prod(a, dim, /*keepdim=*/true, torch::kDouble);
2532 ForEachDevice([&](const torch::Device& device) {
2533 torch::Tensor lazy_a = CopyToDevice(a, device);
2534 torch::Tensor lazy_b =
2535 torch::prod(lazy_a, dim, /*keepdim=*/true, torch::kDouble);
2536 AllClose(b, lazy_b);
2537 });
2538 }
2539}
2540
2541TEST_F(LazyOpsTest, TestProdInDimKeep) {
2542 torch::Tensor a = torch::rand(
2543 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2544 int rank = a.dim();
2545 for (int dim = -rank; dim < rank; ++dim) {
2546 torch::Tensor b = torch::prod(a, dim, /*keepdim=*/true);
2547 ForEachDevice([&](const torch::Device& device) {
2548 torch::Tensor lazy_a = CopyToDevice(a, device);
2549 torch::Tensor lazy_b = torch::prod(lazy_a, dim, /*keepdim=*/true);
2550 AllClose(b, lazy_b);
2551 });
2552 }
2553}
2554
2555TEST_F(LazyOpsTest, TestCumSum) {
2556 torch::Tensor input = torch::rand(
2557 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2558 int rank = input.dim();
2559 for (int dim = -rank; dim < rank; ++dim) {
2560 torch::Tensor result = torch::cumsum(input, dim);
2561 ForEachDevice([&](const torch::Device& device) {
2562 torch::Tensor lazy_input = CopyToDevice(input, device);
2563 torch::Tensor lazy_result = torch::cumsum(lazy_input, dim);
2564 AllClose(result, lazy_result);
2565 });
2566 }
2567}
2568
2569TEST_F(LazyOpsTest, TestCumSumCast) {
2570 torch::Tensor input = torch::rand(
2571 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2572 int rank = input.dim();
2573 for (int dim = -rank; dim < rank; ++dim) {
2574 torch::Tensor result = torch::cumsum(input, dim, torch::kDouble);
2575 ForEachDevice([&](const torch::Device& device) {
2576 torch::Tensor lazy_input = CopyToDevice(input, device);
2577 torch::Tensor lazy_result =
2578 torch::cumsum(lazy_input, dim, torch::kDouble);
2579 AllClose(result, lazy_result);
2580 });
2581 }
2582}
2583
2584TEST_F(LazyOpsTest, TestCumSumLong) {
2585 torch::Tensor input = torch::randint(
2586 1000,
2587 {4, 3, 4},
2588 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
2589 int rank = input.dim();
2590 for (int dim = -rank; dim < rank; ++dim) {
2591 torch::Tensor result = torch::cumsum(input, dim);
2592 ForEachDevice([&](const torch::Device& device) {
2593 torch::Tensor lazy_input = CopyToDevice(input, device);
2594 torch::Tensor lazy_result = torch::cumsum(lazy_input, dim);
2595 AllEqual(result, lazy_result);
2596 });
2597 }
2598}
2599
2600TEST_F(LazyOpsTest, TestCumSumCastLong) {
2601 torch::Tensor input = torch::rand(
2602 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2603 int rank = input.dim();
2604 for (int dim = -rank; dim < rank; ++dim) {
2605 torch::Tensor result = torch::cumsum(input, dim, torch::kLong);
2606 ForEachDevice([&](const torch::Device& device) {
2607 torch::Tensor lazy_input = CopyToDevice(input, device);
2608 torch::Tensor lazy_result = torch::cumsum(lazy_input, dim, torch::kLong);
2609 AllEqual(result, lazy_result);
2610 });
2611 }
2612}
2613
2614TEST_F(LazyOpsTest, TestCumProd) {
2615 torch::Tensor input = torch::rand(
2616 {4, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2617 int rank = input.dim();
2618 for (int dim = -rank; dim < rank; ++dim) {
2619 torch::Tensor result = torch::cumprod(input, dim);
2620 ForEachDevice([&](const torch::Device& device) {
2621 torch::Tensor lazy_input = CopyToDevice(input, device);
2622 torch::Tensor lazy_result = torch::cumprod(lazy_input, dim);
2623 AllClose(result, lazy_result);
2624 });
2625 }
2626}
2627
2628TEST_F(LazyOpsTest, TestCumProdCast) {
2629 torch::Tensor input = torch::mul(
2630 torch::rand(
2631 {4, 3, 4},
2632 torch::TensorOptions(torch::kFloat).device(DefaultDevice())),
2633 10);
2634 int rank = input.dim();
2635 for (int dim = -rank; dim < rank; ++dim) {
2636 torch::Tensor result = torch::cumprod(input, dim, torch::kDouble);
2637 ForEachDevice([&](const torch::Device& device) {
2638 torch::Tensor lazy_input = CopyToDevice(input, device);
2639 torch::Tensor lazy_result =
2640 torch::cumprod(lazy_input, dim, torch::kDouble);
2641 AllClose(result, lazy_result);
2642 });
2643 }
2644}
2645
2646TEST_F(LazyOpsTest, TestCumProdLong) {
2647 torch::Tensor input = torch::randint(
2648 7, {2, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
2649 int rank = input.dim();
2650 for (int dim = -rank; dim < rank; ++dim) {
2651 torch::Tensor result = torch::cumsum(input, dim);
2652 ForEachDevice([&](const torch::Device& device) {
2653 torch::Tensor lazy_input = CopyToDevice(input, device);
2654 torch::Tensor lazy_result = torch::cumsum(lazy_input, dim);
2655 AllEqual(result, lazy_result);
2656 });
2657 }
2658}
2659
2660TEST_F(LazyOpsTest, TestCumProdCastLong) {
2661 torch::Tensor input =
2662 torch::rand(
2663 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
2664 7;
2665 int rank = input.dim();
2666 for (int dim = -rank; dim < rank; ++dim) {
2667 torch::Tensor result = torch::cumsum(input, dim, torch::kLong);
2668 ForEachDevice([&](const torch::Device& device) {
2669 torch::Tensor lazy_input = CopyToDevice(input, device);
2670 torch::Tensor lazy_result = torch::cumsum(lazy_input, dim, torch::kLong);
2671 AllEqual(result, lazy_result);
2672 });
2673 }
2674}
2675
2676TEST_F(LazyOpsTest, TestArgMin) {
2677 torch::Tensor a = torch::rand(
2678 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2679 torch::Tensor b = torch::argmin(a, c10::nullopt, /*keepdim=*/false);
2680 ForEachDevice([&](const torch::Device& device) {
2681 torch::Tensor lazy_a = CopyToDevice(a, device);
2682 torch::Tensor lazy_b =
2683 torch::argmin(lazy_a, c10::nullopt, /*keepdim=*/false);
2684 AllEqual(b, lazy_b);
2685 });
2686}
2687
2688TEST_F(LazyOpsTest, TestArgMinDim) {
2689 torch::Tensor a = torch::rand(
2690 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2691 for (int dim : {1, -2}) {
2692 torch::Tensor b = torch::argmin(a, dim, /*keepdim=*/false);
2693 ForEachDevice([&](const torch::Device& device) {
2694 torch::Tensor lazy_a = CopyToDevice(a, device);
2695 torch::Tensor lazy_b = torch::argmin(lazy_a, dim, /*keepdim=*/false);
2696 AllEqual(b, lazy_b);
2697 });
2698 }
2699}
2700
2701TEST_F(LazyOpsTest, TestArgMinDimKeep) {
2702 torch::Tensor a = torch::rand(
2703 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2704 for (int dim : {1, -2}) {
2705 torch::Tensor b = torch::argmin(a, dim, /*keepdim=*/true);
2706 ForEachDevice([&](const torch::Device& device) {
2707 torch::Tensor lazy_a = CopyToDevice(a, device);
2708 torch::Tensor lazy_b = torch::argmin(lazy_a, dim, /*keepdim=*/true);
2709 AllEqual(b, lazy_b);
2710 });
2711 }
2712}
2713
2714TEST_F(LazyOpsTest, TestArgMinSameValue) {
2715 torch::Tensor a = torch::ones(
2716 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2717 torch::Tensor b = torch::argmin(a);
2718 ForEachDevice([&](const torch::Device& device) {
2719 torch::Tensor lazy_a = CopyToDevice(a, device);
2720 torch::Tensor lazy_b = torch::argmin(lazy_a);
2721 AllEqual(b, lazy_b);
2722 });
2723}
2724
2725TEST_F(LazyOpsTest, TestArgMinWrapper) {
2726 torch::Tensor a = torch::rand(
2727 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2728 for (int dim : {1, -2}) {
2729 torch::Tensor b = torch::argmin(a, dim, /*keepdim=*/false);
2730 ForEachDevice([&](const torch::Device& device) {
2731 torch::Tensor lazy_a = CopyToDevice(a, device);
2732 torch::Tensor lazy_b = torch::argmin(lazy_a, dim, /*keepdim=*/false);
2733 AllEqual(b, lazy_b);
2734 });
2735 }
2736}
2737
2738TEST_F(LazyOpsTest, TestArgMax) {
2739 torch::Tensor a = torch::rand(
2740 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2741 torch::Tensor b = torch::argmax(a, c10::nullopt, /*keepdim=*/false);
2742 ForEachDevice([&](const torch::Device& device) {
2743 torch::Tensor lazy_a = CopyToDevice(a, device);
2744 torch::Tensor lazy_b =
2745 torch::argmax(lazy_a, c10::nullopt, /*keepdim=*/false);
2746 AllEqual(b, lazy_b);
2747 });
2748}
2749
2750TEST_F(LazyOpsTest, TestArgMaxDim) {
2751 torch::Tensor a = torch::rand(
2752 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2753 for (int dim : {1, -2}) {
2754 torch::Tensor b = torch::argmax(a, dim, /*keepdim=*/false);
2755 ForEachDevice([&](const torch::Device& device) {
2756 torch::Tensor lazy_a = CopyToDevice(a, device);
2757 torch::Tensor lazy_b = torch::argmax(lazy_a, dim, /*keepdim=*/false);
2758 AllEqual(b, lazy_b);
2759 });
2760 }
2761}
2762
2763TEST_F(LazyOpsTest, TestArgMaxDimKeep) {
2764 torch::Tensor a = torch::rand(
2765 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2766 for (int dim : {1, -2}) {
2767 torch::Tensor b = torch::argmax(a, dim, /*keepdim=*/true);
2768 ForEachDevice([&](const torch::Device& device) {
2769 torch::Tensor lazy_a = CopyToDevice(a, device);
2770 torch::Tensor lazy_b = torch::argmax(lazy_a, dim, /*keepdim=*/true);
2771 AllEqual(b, lazy_b);
2772 });
2773 }
2774}
2775
2776TEST_F(LazyOpsTest, TestArgMaxSameValue) {
2777 torch::Tensor a = torch::ones(
2778 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2779 torch::Tensor b = torch::argmax(a, c10::nullopt, /*keepdim=*/false);
2780 ForEachDevice([&](const torch::Device& device) {
2781 torch::Tensor lazy_a = CopyToDevice(a, device);
2782 torch::Tensor lazy_b =
2783 torch::argmax(lazy_a, c10::nullopt, /*keepdim=*/false);
2784 AllEqual(b, lazy_b);
2785 });
2786}
2787
2788TEST_F(LazyOpsTest, TestArgMaxWrapper) {
2789 torch::Tensor a = torch::rand(
2790 {4, 4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2791 for (int dim : {1, -2}) {
2792 torch::Tensor b = torch::argmax(a, dim, /*keepdim=*/false);
2793 ForEachDevice([&](const torch::Device& device) {
2794 torch::Tensor lazy_a = CopyToDevice(a, device);
2795 torch::Tensor lazy_b = torch::argmax(lazy_a, dim, /*keepdim=*/false);
2796 AllEqual(b, lazy_b);
2797 });
2798 }
2799}
2800
2801TEST_F(LazyOpsTest, TestAsin) {
2802 torch::Tensor a = torch::rand(
2803 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2804 torch::Tensor b = torch::asin(a);
2805 ForEachDevice([&](const torch::Device& device) {
2806 torch::Tensor lazy_a = CopyToDevice(a, device);
2807 torch::Tensor lazy_b = torch::asin(lazy_a);
2808 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2809 });
2810}
2811
2812TEST_F(LazyOpsTest, TestAsinh) {
2813 torch::Tensor a = torch::rand(
2814 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2815 torch::Tensor b = torch::asinh(a);
2816 ForEachDevice([&](const torch::Device& device) {
2817 torch::Tensor lazy_a = CopyToDevice(a, device);
2818 torch::Tensor lazy_b = torch::asinh(lazy_a);
2819 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2820 });
2821}
2822
2823TEST_F(LazyOpsTest, TestAsinhInPlace) {
2824 torch::Tensor a = torch::rand(
2825 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2826 ForEachDevice([&](const torch::Device& device) {
2827 torch::Tensor lazy_a = CopyToDevice(a, device);
2828 torch::Tensor b = torch::asinh_(a);
2829 torch::Tensor lazy_b = torch::asinh_(lazy_a);
2830 AllClose(a, lazy_a, /*rtol=*/1e-3, /*atol=*/1e-5);
2831 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2832 });
2833}
2834
2835TEST_F(LazyOpsTest, TestSin) {
2836 torch::Tensor a = torch::rand(
2837 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2838 torch::Tensor b = torch::sin(a);
2839 ForEachDevice([&](const torch::Device& device) {
2840 torch::Tensor lazy_a = CopyToDevice(a, device);
2841 torch::Tensor lazy_b = torch::sin(lazy_a);
2842 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2843 });
2844}
2845
2846TEST_F(LazyOpsTest, TestSinh) {
2847 torch::Tensor a = torch::rand(
2848 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2849 torch::Tensor b = torch::sinh(a);
2850 ForEachDevice([&](const torch::Device& device) {
2851 torch::Tensor lazy_a = CopyToDevice(a, device);
2852 torch::Tensor lazy_b = torch::sinh(lazy_a);
2853 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2854 });
2855}
2856
2857TEST_F(LazyOpsTest, TestAcos) {
2858 torch::Tensor a = torch::rand(
2859 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2860 torch::Tensor b = torch::acos(a);
2861 ForEachDevice([&](const torch::Device& device) {
2862 torch::Tensor lazy_a = CopyToDevice(a, device);
2863 torch::Tensor lazy_b = torch::acos(lazy_a);
2864 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2865 });
2866}
2867
2868TEST_F(LazyOpsTest, TestAcosh) {
2869 torch::Tensor a =
2870 torch::rand(
2871 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
2872 100;
2873 torch::Tensor b = torch::acosh(a);
2874 ForEachDevice([&](const torch::Device& device) {
2875 torch::Tensor lazy_a = CopyToDevice(a, device);
2876 torch::Tensor lazy_b = torch::acosh(lazy_a);
2877 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2878 });
2879}
2880
2881TEST_F(LazyOpsTest, TestAcoshInPlace) {
2882 torch::Tensor a =
2883 torch::rand(
2884 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
2885 100;
2886 ForEachDevice([&](const torch::Device& device) {
2887 torch::Tensor lazy_a = CopyToDevice(a, device);
2888 torch::Tensor b = torch::acosh_(a);
2889 torch::Tensor lazy_b = torch::acosh_(lazy_a);
2890 AllClose(a, lazy_a, /*rtol=*/1e-3, /*atol=*/1e-5);
2891 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2892 });
2893}
2894
2895TEST_F(LazyOpsTest, TestCos) {
2896 torch::Tensor a = torch::rand(
2897 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2898 torch::Tensor b = torch::cos(a);
2899 ForEachDevice([&](const torch::Device& device) {
2900 torch::Tensor lazy_a = CopyToDevice(a, device);
2901 torch::Tensor lazy_b = torch::cos(lazy_a);
2902 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2903 });
2904}
2905
2906TEST_F(LazyOpsTest, TestCosh) {
2907 torch::Tensor a = torch::rand(
2908 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2909 torch::Tensor b = torch::cosh(a);
2910 ForEachDevice([&](const torch::Device& device) {
2911 torch::Tensor lazy_a = CopyToDevice(a, device);
2912 torch::Tensor lazy_b = torch::cosh(lazy_a);
2913 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2914 });
2915}
2916
2917TEST_F(LazyOpsTest, TestAtan) {
2918 torch::Tensor a = torch::rand(
2919 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2920 torch::Tensor b = torch::atan(a);
2921 ForEachDevice([&](const torch::Device& device) {
2922 torch::Tensor lazy_a = CopyToDevice(a, device);
2923 torch::Tensor lazy_b = torch::atan(lazy_a);
2924 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2925 });
2926}
2927
2928TEST_F(LazyOpsTest, TestAtanh) {
2929 torch::Tensor a = torch::rand(
2930 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2931 torch::Tensor b = torch::atanh(a);
2932 ForEachDevice([&](const torch::Device& device) {
2933 torch::Tensor lazy_a = CopyToDevice(a, device);
2934 torch::Tensor lazy_b = torch::atanh(lazy_a);
2935 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2936 });
2937}
2938
2939TEST_F(LazyOpsTest, TestAtanhInPlace) {
2940 torch::Tensor a = torch::rand(
2941 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2942 ForEachDevice([&](const torch::Device& device) {
2943 torch::Tensor lazy_a = CopyToDevice(a, device);
2944 torch::Tensor b = torch::atanh_(a);
2945 torch::Tensor lazy_b = torch::atanh_(lazy_a);
2946 AllClose(a, lazy_a, /*rtol=*/1e-3, /*atol=*/1e-5);
2947 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2948 });
2949}
2950
2951TEST_F(LazyOpsTest, TestAtan2) {
2952 torch::Tensor a = torch::randn(
2953 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2954 torch::Tensor b = torch::randn(
2955 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2956 torch::Tensor c = torch::atan2(a, b);
2957 ForEachDevice([&](const torch::Device& device) {
2958 torch::Tensor lazy_a = CopyToDevice(a, device);
2959 torch::Tensor lazy_b = CopyToDevice(b, device);
2960 torch::Tensor lazy_c = torch::atan2(lazy_a, lazy_b);
2961 AllClose(c, lazy_c, /*rtol=*/1e-3, /*atol=*/1e-5);
2962 });
2963}
2964
2965TEST_F(LazyOpsTest, TestTan) {
2966 torch::Tensor a = torch::rand(
2967 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2968 torch::Tensor b = torch::tan(a);
2969 ForEachDevice([&](const torch::Device& device) {
2970 torch::Tensor lazy_a = CopyToDevice(a, device);
2971 torch::Tensor lazy_b = torch::tan(lazy_a);
2972 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2973 });
2974}
2975
2976TEST_F(LazyOpsTest, TestTanh) {
2977 torch::Tensor a = torch::rand(
2978 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2979 torch::Tensor b = torch::tanh(a);
2980 ForEachDevice([&](const torch::Device& device) {
2981 torch::Tensor lazy_a = CopyToDevice(a, device);
2982 torch::Tensor lazy_b = torch::tanh(lazy_a);
2983 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
2984 });
2985}
2986
2987TEST_F(LazyOpsTest, TestClampMinMax) {
2988 torch::Tensor a = torch::rand(
2989 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
2990 torch::Scalar min_val(0.311);
2991 torch::Scalar max_val(0.409);
2992 torch::Tensor b = torch::clamp(a, min_val, max_val);
2993 ForEachDevice([&](const torch::Device& device) {
2994 torch::Tensor lazy_a = CopyToDevice(a, device);
2995 torch::Tensor lazy_b = torch::clamp(lazy_a, min_val, max_val);
2996 AllClose(b, lazy_b);
2997 });
2998}
2999
3000TEST_F(LazyOpsTest, TestClampMin) {
3001 torch::Tensor a = torch::rand(
3002 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3003 torch::Scalar min_val(0.311);
3004 torch::Tensor b = torch::clamp(a, min_val, c10::nullopt);
3005 ForEachDevice([&](const torch::Device& device) {
3006 torch::Tensor lazy_a = CopyToDevice(a, device);
3007 torch::Tensor lazy_b = torch::clamp(lazy_a, min_val, c10::nullopt);
3008 AllClose(b, lazy_b);
3009 });
3010}
3011
3012TEST_F(LazyOpsTest, TestClampMax) {
3013 torch::Tensor a = torch::rand(
3014 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3015 torch::Scalar max_val(0.409);
3016 torch::Tensor b = torch::clamp(a, c10::nullopt, max_val);
3017 ForEachDevice([&](const torch::Device& device) {
3018 torch::Tensor lazy_a = CopyToDevice(a, device);
3019 torch::Tensor lazy_b = torch::clamp(lazy_a, c10::nullopt, max_val);
3020 AllClose(b, lazy_b);
3021 });
3022}
3023
3024TEST_F(LazyOpsTest, TestClampMinExplicit) {
3025 torch::Tensor a = torch::rand(
3026 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3027 torch::Scalar min_val(0.311);
3028 torch::Tensor b = torch::clamp_min(a, min_val);
3029 ForEachDevice([&](const torch::Device& device) {
3030 torch::Tensor lazy_a = CopyToDevice(a, device);
3031 torch::Tensor lazy_b = torch::clamp_min(lazy_a, min_val);
3032 AllClose(b, lazy_b);
3033 });
3034}
3035
3036TEST_F(LazyOpsTest, TestClampMaxExplicit) {
3037 torch::Tensor a = torch::rand(
3038 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3039 torch::Scalar max_val(0.409);
3040 torch::Tensor b = torch::clamp_max(a, max_val);
3041 ForEachDevice([&](const torch::Device& device) {
3042 torch::Tensor lazy_a = CopyToDevice(a, device);
3043 torch::Tensor lazy_b = torch::clamp_max(lazy_a, max_val);
3044 AllClose(b, lazy_b);
3045 });
3046}
3047
3048TEST_F(LazyOpsTest, TestClampMinExplicitInPlace) {
3049 torch::Tensor a = torch::rand(
3050 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3051 torch::Scalar min_val(0.311);
3052 ForEachDevice([&](const torch::Device& device) {
3053 torch::Tensor lazy_a = CopyToDevice(a, device);
3054 torch::Tensor b = torch::clamp_min_(a, min_val);
3055 torch::Tensor lazy_b = torch::clamp_min_(lazy_a, min_val);
3056 AllClose(a, lazy_a);
3057 AllClose(b, lazy_b);
3058 });
3059}
3060
3061TEST_F(LazyOpsTest, TestClampMaxExplicitInPlace) {
3062 torch::Tensor a = torch::rand(
3063 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3064 torch::Scalar max_val(0.409);
3065 ForEachDevice([&](const torch::Device& device) {
3066 torch::Tensor lazy_a = CopyToDevice(a, device);
3067 torch::Tensor b = torch::clamp_max_(a, max_val);
3068 torch::Tensor lazy_b = torch::clamp_max_(lazy_a, max_val);
3069 AllClose(a, lazy_a);
3070 AllClose(b, lazy_b);
3071 });
3072}
3073
3074TEST_F(LazyOpsTest, TestCeil) {
3075 torch::Tensor a =
3076 torch::randn(
3077 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
3078 100.0;
3079 torch::Tensor b = torch::ceil(a);
3080 ForEachDevice([&](const torch::Device& device) {
3081 torch::Tensor lazy_a = CopyToDevice(a, device);
3082 torch::Tensor lazy_b = torch::ceil(lazy_a);
3083 AllClose(b, lazy_b);
3084 });
3085}
3086
3087TEST_F(LazyOpsTest, TestFloor) {
3088 torch::Tensor a =
3089 torch::randn(
3090 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
3091 100.0;
3092 torch::Tensor b = torch::floor(a);
3093 ForEachDevice([&](const torch::Device& device) {
3094 torch::Tensor lazy_a = CopyToDevice(a, device);
3095 torch::Tensor lazy_b = torch::floor(lazy_a);
3096 AllClose(b, lazy_b);
3097 });
3098}
3099
3100TEST_F(LazyOpsTest, TestRound) {
3101 torch::Tensor a = torch::cat(
3102 {torch::randn(
3103 {8}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
3104 100.0,
3105 // Special case: 0.5, -0.5. lazy::Round impl rounds to -1/1 whereas
3106 // lazy::RoundToEven properly implements bankers rounding.
3107 torch::tensor(
3108 {-0.5, 0.5},
3109 torch::TensorOptions(torch::kFloat).device(DefaultDevice()))},
3110 0);
3111 torch::Tensor b = torch::round(a);
3112 ForEachDevice([&](const torch::Device& device) {
3113 torch::Tensor lazy_a = CopyToDevice(a, device);
3114 torch::Tensor lazy_b = torch::round(lazy_a);
3115 AllClose(b, lazy_b);
3116 });
3117}
3118
3119TEST_F(LazyOpsTest, TestTrunc) {
3120 torch::Tensor a =
3121 torch::randn(
3122 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
3123 100.0;
3124 torch::Tensor b = torch::trunc(a);
3125 ForEachDevice([&](const torch::Device& device) {
3126 torch::Tensor lazy_a = CopyToDevice(a, device);
3127 torch::Tensor lazy_b = torch::trunc(lazy_a);
3128 AllClose(b, lazy_b);
3129 });
3130}
3131
3132TEST_F(LazyOpsTest, TestFrac) {
3133 torch::Tensor a =
3134 torch::randn(
3135 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
3136 100.0;
3137 torch::Tensor b = torch::frac(a);
3138 ForEachDevice([&](const torch::Device& device) {
3139 torch::Tensor lazy_a = CopyToDevice(a, device);
3140 torch::Tensor lazy_b = torch::frac(lazy_a);
3141 AllClose(b, lazy_b);
3142 });
3143}
3144
3145TEST_F(LazyOpsTest, TestNeg) {
3146 torch::Tensor a = torch::rand(
3147 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3148 torch::Tensor b = torch::neg(a);
3149 ForEachDevice([&](const torch::Device& device) {
3150 torch::Tensor lazy_a = CopyToDevice(a, device);
3151 torch::Tensor lazy_b = torch::neg(lazy_a);
3152 AllClose(b, lazy_b);
3153 });
3154}
3155
3156TEST_F(LazyOpsTest, TestBitwiseNot) {
3157 std::vector<torch::ScalarType> types(
3158 {torch::kByte, torch::kChar, torch::kShort, torch::kInt, torch::kLong});
3159
3160 ForEachDevice([&](const torch::Device& device) {
3161 for (auto type : types) {
3162 torch::Tensor a =
3163 torch::randint(0, 63, {2, 2}, torch::TensorOptions(type));
3164 torch::Tensor b = torch::bitwise_not(a);
3165 torch::Tensor lazy_a = CopyToDevice(a, device);
3166 torch::Tensor lazy_b = torch::bitwise_not(lazy_a);
3167 AllEqual(b, lazy_b);
3168 }
3169 });
3170}
3171
3172TEST_F(LazyOpsTest, TestBitwiseNotInPlace) {
3173 std::vector<torch::ScalarType> types(
3174 {torch::kByte, torch::kChar, torch::kShort, torch::kInt, torch::kLong});
3175
3176 ForEachDevice([&](const torch::Device& device) {
3177 for (auto type : types) {
3178 torch::Tensor a =
3179 torch::randint(0, 63, {2, 2}, torch::TensorOptions(type));
3180 torch::Tensor lazy_a = CopyToDevice(a, device);
3181 a.bitwise_not_();
3182 lazy_a.bitwise_not_();
3183 AllEqual(a, lazy_a);
3184 }
3185 });
3186}
3187
3188TEST_F(LazyOpsTest, TestSign) {
3189 torch::Tensor a =
3190 torch::randn(
3191 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
3192 100.0;
3193 torch::Tensor b = torch::sign(a);
3194 ForEachDevice([&](const torch::Device& device) {
3195 torch::Tensor lazy_a = CopyToDevice(a, device);
3196 torch::Tensor lazy_b = torch::sign(lazy_a);
3197 AllClose(b, lazy_b);
3198 });
3199}
3200
3201TEST_F(LazyOpsTest, TestSignByte) {
3202 torch::Tensor a = torch::randint(
3203 256, {2, 2}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
3204 torch::Tensor b = torch::sign(a);
3205 ForEachDevice([&](const torch::Device& device) {
3206 torch::Tensor lazy_a = CopyToDevice(a, device);
3207 torch::Tensor lazy_b = torch::sign(lazy_a);
3208 AllEqual(b, lazy_b);
3209 });
3210}
3211
3212TEST_F(LazyOpsTest, TestAbs) {
3213 torch::Tensor a = torch::randn(
3214 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3215 torch::Tensor b = torch::abs(a);
3216 ForEachDevice([&](const torch::Device& device) {
3217 torch::Tensor lazy_a = CopyToDevice(a, device);
3218 torch::Tensor lazy_b = torch::abs(lazy_a);
3219 AllClose(b, lazy_b);
3220 });
3221}
3222
3223TEST_F(LazyOpsTest, TestAbsByte) {
3224 torch::Tensor a = torch::randint(
3225 256, {2, 2}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
3226 torch::Tensor b = torch::abs(a);
3227 ForEachDevice([&](const torch::Device& device) {
3228 torch::Tensor lazy_a = CopyToDevice(a, device);
3229 torch::Tensor lazy_b = torch::abs(lazy_a);
3230 AllEqual(b, lazy_b);
3231 });
3232}
3233
3234TEST_F(LazyOpsTest, TestEmptyLike) {
3235 torch::Tensor a = torch::rand(
3236 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3237 torch::Tensor b = torch::empty_like(a);
3238 ForEachDevice([&](const torch::Device& device) {
3239 torch::Tensor lazy_a = CopyToDevice(a, device);
3240 torch::Tensor lazy_b = torch::empty_like(lazy_a);
3241 EXPECT_EQ(b.sizes(), lazy_b.sizes());
3242 });
3243}
3244
3245TEST_F(LazyOpsTest, TestEmptyLikeOptions) {
3246 torch::Tensor a = torch::rand(
3247 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3248 torch::Tensor b = torch::empty_like(
3249 a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3250 ForEachDevice([&](const torch::Device& device) {
3251 torch::Tensor lazy_a = CopyToDevice(a, device);
3252 torch::Tensor lazy_b = torch::empty_like(
3253 lazy_a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3254 EXPECT_EQ(b.sizes(), lazy_b.sizes());
3255 });
3256}
3257
3258TEST_F(LazyOpsTest, TestEmpty) {
3259 torch::Tensor a = torch::zeros(
3260 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3261 ForEachDevice([&](const torch::Device& device) {
3262 torch::Tensor lazy_a = torch::empty(
3263 {2, 2}, torch::TensorOptions(torch::kFloat).device(device));
3264 EXPECT_EQ(a.sizes(), lazy_a.sizes());
3265 });
3266}
3267
3268TEST_F(LazyOpsTest, TestZeroInPlace) {
3269 torch::Tensor input = torch::ones(
3270 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3271
3272 ForEachDevice([&](const torch::Device& device) {
3273 torch::Tensor lazyInput = CopyToDevice(input, device);
3274 auto& output = torch::zero_(input);
3275 auto& lazyOutput = torch::zero_(lazyInput);
3276 AllClose(output, lazyOutput);
3277 });
3278}
3279
3280TEST_F(LazyOpsTest, TestZerosLike) {
3281 torch::Tensor a = torch::rand(
3282 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3283 torch::Tensor b = torch::zeros_like(a);
3284 ForEachDevice([&](const torch::Device& device) {
3285 torch::Tensor lazy_a = CopyToDevice(a, device);
3286 torch::Tensor lazy_b = torch::zeros_like(lazy_a);
3287 AllClose(a, lazy_a);
3288 });
3289}
3290
3291TEST_F(LazyOpsTest, TestZerosLikeOptions) {
3292 torch::Tensor a = torch::rand(
3293 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3294 torch::Tensor b = torch::zeros_like(
3295 a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3296 ForEachDevice([&](const torch::Device& device) {
3297 torch::Tensor lazy_a = CopyToDevice(a, device);
3298 torch::Tensor lazy_b = torch::zeros_like(
3299 lazy_a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3300 AllClose(a, lazy_a);
3301 });
3302}
3303
3304TEST_F(LazyOpsTest, TestZeros) {
3305 torch::Tensor a = torch::zeros(
3306 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3307 ForEachDevice([&](const torch::Device& device) {
3308 torch::Tensor lazy_a = torch::zeros(
3309 {2, 2}, torch::TensorOptions(torch::kFloat).device(device));
3310 AllClose(a, lazy_a);
3311 });
3312}
3313
3314TEST_F(LazyOpsTest, TestOnes) {
3315 torch::Tensor a = torch::ones(
3316 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3317 ForEachDevice([&](const torch::Device& device) {
3318 torch::Tensor lazy_a =
3319 torch::ones({2, 2}, torch::TensorOptions(torch::kFloat).device(device));
3320 AllClose(a, lazy_a);
3321 });
3322}
3323
3324TEST_F(LazyOpsTest, TestOnesLike) {
3325 torch::Tensor a = torch::rand(
3326 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3327 torch::Tensor b = torch::ones_like(a);
3328 ForEachDevice([&](const torch::Device& device) {
3329 torch::Tensor lazy_a = CopyToDevice(a, device);
3330 torch::Tensor lazy_b = torch::ones_like(lazy_a);
3331 AllClose(a, lazy_a);
3332 });
3333}
3334
3335TEST_F(LazyOpsTest, TestOnesLikeOptions) {
3336 torch::Tensor a = torch::rand(
3337 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3338 torch::Tensor b = torch::ones_like(
3339 a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3340 ForEachDevice([&](const torch::Device& device) {
3341 torch::Tensor lazy_a = CopyToDevice(a, device);
3342 torch::Tensor lazy_b = torch::ones_like(
3343 lazy_a, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3344 AllClose(a, lazy_a);
3345 });
3346}
3347
3348TEST_F(LazyOpsTest, TestFull) {
3349 torch::Tensor a = torch::full(
3350 {2, 2},
3351 3.1165,
3352 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3353 ForEachDevice([&](const torch::Device& device) {
3354 torch::Tensor lazy_a = torch::full(
3355 {2, 2}, 3.1165, torch::TensorOptions(torch::kFloat).device(device));
3356 AllClose(a, lazy_a);
3357 });
3358}
3359
3360TEST_F(LazyOpsTest, TestFullLike) {
3361 torch::Tensor a = torch::rand(
3362 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3363 torch::Tensor b = torch::full_like(a, 3.1165);
3364 ForEachDevice([&](const torch::Device& device) {
3365 torch::Tensor lazy_a = CopyToDevice(a, device);
3366 torch::Tensor lazy_b = torch::full_like(lazy_a, 3.1165);
3367 AllClose(a, lazy_a);
3368 });
3369}
3370
3371TEST_F(LazyOpsTest, TestFullLikeOptions) {
3372 torch::Tensor a = torch::rand(
3373 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3374 torch::Tensor b = torch::full_like(
3375 a, 3.1165, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3376 ForEachDevice([&](const torch::Device& device) {
3377 torch::Tensor lazy_a = CopyToDevice(a, device);
3378 torch::Tensor lazy_b = torch::full_like(
3379 lazy_a,
3380 3.1165,
3381 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3382 AllClose(a, lazy_a);
3383 });
3384}
3385
3386TEST_F(LazyOpsTest, TestARange) {
3387 for (auto& ranges : std::vector<std::vector<float>>{
3388 {0.0, 100.0, 0.5}, {0.0, -100.0, -0.5}}) {
3389 torch::Tensor a = torch::arange(
3390 ranges[0],
3391 ranges[1],
3392 ranges[2],
3393 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3394 ForEachDevice([&](const torch::Device& device) {
3395 torch::Tensor lazy_a = torch::arange(
3396 ranges[0],
3397 ranges[1],
3398 ranges[2],
3399 torch::TensorOptions(torch::kFloat).device(device));
3400 AllClose(a, lazy_a);
3401 });
3402 }
3403}
3404
3405TEST_F(LazyOpsTest, TestARangeOut) {
3406 torch::Tensor a = torch::randn(
3407 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3408 for (auto& ranges : std::vector<std::vector<float>>{
3409 {0.0, 100.0, 0.5}, {0.0, -100.0, -0.5}}) {
3410 torch::Tensor b = torch::arange_out(a, ranges[0], ranges[1], ranges[2]);
3411 ForEachDevice([&](const torch::Device& device) {
3412 torch::Tensor lazy_a = CopyToDevice(a, device);
3413 torch::Tensor lazy_b =
3414 torch::arange_out(lazy_a, ranges[0], ranges[1], ranges[2]);
3415 AllClose(b, lazy_b);
3416 });
3417 }
3418}
3419
3420TEST_F(LazyOpsTest, TestDimARange) {
3421 torch::Tensor like = torch::rand(
3422 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3423 torch::Tensor a = torch::_dim_arange(like, 1);
3424 ForEachDevice([&](const torch::Device& device) {
3425 torch::Tensor lazy_like = CopyToDevice(like, device);
3426 torch::Tensor lazy_a = torch::_dim_arange(lazy_like, 1);
3427 AllClose(a, lazy_a);
3428 });
3429}
3430
3431TEST_F(LazyOpsTest, TestBartlettWindow) {
3432 int window_length = 10;
3433 for (bool periodic : {false, true}) {
3434 ForEachDevice([&](const torch::Device& device) {
3435 torch::Tensor output = torch::bartlett_window(
3436 window_length,
3437 periodic,
3438 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3439
3440 torch::Tensor lazy_output = torch::bartlett_window(
3441 window_length,
3442 periodic,
3443 torch::TensorOptions(torch::kFloat).device(device));
3444 AllClose(output, lazy_output, /*rtol=*/1e-5, /*atol=*/1e-7);
3445 });
3446 }
3447}
3448
3449TEST_F(LazyOpsTest, TestBlackmanWindow) {
3450 int window_length = 10;
3451 for (bool periodic : {false, true}) {
3452 ForEachDevice([&](const torch::Device& device) {
3453 torch::Tensor output = torch::blackman_window(
3454 window_length,
3455 periodic,
3456 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3457 torch::Tensor lazy_output = torch::blackman_window(
3458 window_length,
3459 periodic,
3460 torch::TensorOptions(torch::kFloat).device(device));
3461 AllClose(output, lazy_output, /*rtol=*/1e-5, /*atol=*/1e-7);
3462 });
3463 }
3464}
3465
3466TEST_F(LazyOpsTest, TestHammingWindow) {
3467 double alpha = 0.54;
3468 double beta = 0.46;
3469 int window_length = 10;
3470 for (bool periodic : {false, true}) {
3471 ForEachDevice([&](const torch::Device& device) {
3472 torch::Tensor output = torch::hamming_window(
3473 window_length,
3474 periodic,
3475 alpha,
3476 beta,
3477 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3478 torch::Tensor lazy_output = torch::hamming_window(
3479 window_length,
3480 periodic,
3481 alpha,
3482 beta,
3483 torch::TensorOptions(torch::kFloat).device(device));
3484 AllClose(output, lazy_output);
3485 });
3486 }
3487}
3488
3489TEST_F(LazyOpsTest, TestHannWindow) {
3490 int window_length = 10;
3491 for (bool periodic : {false, true}) {
3492 ForEachDevice([&](const torch::Device& device) {
3493 torch::Tensor output = torch::hann_window(
3494 window_length,
3495 periodic,
3496 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3497 torch::Tensor lazy_output = torch::hann_window(
3498 window_length,
3499 periodic,
3500 torch::TensorOptions(torch::kFloat).device(device));
3501 AllClose(output, lazy_output);
3502 });
3503 }
3504}
3505
3506TEST_F(LazyOpsTest, TestLogSigmoid) {
3507 torch::Tensor a = torch::empty(
3508 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3509 a.uniform_(-1.0, 1.0);
3510 torch::Tensor b = torch::log_sigmoid(a);
3511 ForEachDevice([&](const torch::Device& device) {
3512 torch::Tensor lazy_a = CopyToDevice(a, device);
3513 torch::Tensor lazy_b = torch::log_sigmoid(lazy_a);
3514 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
3515 });
3516}
3517
3518TEST_F(LazyOpsTest, TestLogSigmoidForward) {
3519 torch::Tensor a = torch::empty(
3520 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3521 a.uniform_(-1.0, 1.0);
3522 auto tuple = torch::log_sigmoid_forward(a);
3523 ForEachDevice([&](const torch::Device& device) {
3524 torch::Tensor lazy_a = CopyToDevice(a, device);
3525 auto lazy_tuple = torch::log_sigmoid_forward(lazy_a);
3526 AllClose(
3527 std::get<0>(tuple),
3528 std::get<0>(lazy_tuple),
3529 /*rtol=*/1e-3,
3530 /*atol=*/1e-5);
3531 AllClose(
3532 std::get<1>(tuple),
3533 std::get<1>(lazy_tuple),
3534 /*rtol=*/1e-3,
3535 /*atol=*/1e-5);
3536 });
3537}
3538
3539TEST_F(LazyOpsTest, TestLogsumexp) {
3540 torch::Tensor a = torch::rand(
3541 {3, 4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3542 for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
3543 for (bool keepdim : {false, true}) {
3544 torch::Tensor b = torch::logsumexp(a, dims, keepdim);
3545 ForEachDevice([&](const torch::Device& device) {
3546 torch::Tensor lazy_a = CopyToDevice(a, device);
3547 torch::Tensor lazy_b = torch::logsumexp(lazy_a, dims, keepdim);
3548 AllClose(b, lazy_b);
3549 });
3550 }
3551 }
3552}
3553
3554TEST_F(LazyOpsTest, TestSiLU) {
3555 torch::Tensor a = torch::rand(
3556 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3557 torch::Tensor b = torch::silu(a);
3558 ForEachDevice([&](const torch::Device& device) {
3559 torch::Tensor lazy_a = CopyToDevice(a, device);
3560 torch::Tensor lazy_b = torch::silu(lazy_a);
3561 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
3562 });
3563 ExpectCounterChanged("lazy::silu_out", GetIgnoredCounters());
3564}
3565
3566TEST_F(LazyOpsTest, TestSigmoid) {
3567 torch::Tensor a = torch::rand(
3568 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3569 torch::Tensor b = torch::sigmoid(a);
3570 ForEachDevice([&](const torch::Device& device) {
3571 torch::Tensor lazy_a = CopyToDevice(a, device);
3572 torch::Tensor lazy_b = torch::sigmoid(lazy_a);
3573 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
3574 });
3575}
3576
3577TEST_F(LazyOpsTest, TestMatmul_1x1) {
3578 torch::Tensor a = torch::rand(
3579 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3580 torch::Tensor b = torch::rand(
3581 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3582 torch::Tensor c = torch::matmul(a, b);
3583 ForEachDevice([&](const torch::Device& device) {
3584 torch::Tensor lazy_a = CopyToDevice(a, device);
3585 torch::Tensor lazy_b = CopyToDevice(b, device);
3586 torch::Tensor lazy_c = torch::matmul(lazy_a, lazy_b);
3587 AllClose(c, lazy_c);
3588 });
3589}
3590
3591TEST_F(LazyOpsTest, TestMatmul_2x1) {
3592 torch::Tensor a = torch::rand(
3593 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3594 torch::Tensor b = torch::rand(
3595 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3596 torch::Tensor c = torch::matmul(a, b);
3597 ForEachDevice([&](const torch::Device& device) {
3598 torch::Tensor lazy_a = CopyToDevice(a, device);
3599 torch::Tensor lazy_b = CopyToDevice(b, device);
3600 torch::Tensor lazy_c = torch::matmul(lazy_a, lazy_b);
3601 AllClose(c, lazy_c);
3602 });
3603}
3604
3605TEST_F(LazyOpsTest, TestMatmul_1x2) {
3606 torch::Tensor a = torch::rand(
3607 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3608 torch::Tensor b = torch::rand(
3609 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3610 torch::Tensor c = torch::matmul(a, b);
3611 ForEachDevice([&](const torch::Device& device) {
3612 torch::Tensor lazy_a = CopyToDevice(a, device);
3613 torch::Tensor lazy_b = CopyToDevice(b, device);
3614 torch::Tensor lazy_c = torch::matmul(lazy_a, lazy_b);
3615 AllClose(c, lazy_c);
3616 });
3617}
3618
3619TEST_F(LazyOpsTest, TestMatmul_2x2) {
3620 torch::Tensor a = torch::rand(
3621 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3622 torch::Tensor b = torch::rand(
3623 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3624 torch::Tensor c = torch::matmul(a, b);
3625 ForEachDevice([&](const torch::Device& device) {
3626 torch::Tensor lazy_a = CopyToDevice(a, device);
3627 torch::Tensor lazy_b = CopyToDevice(b, device);
3628 torch::Tensor lazy_c = torch::matmul(lazy_a, lazy_b);
3629 AllClose(c, lazy_c, /*rtol=*/1e-3, /*atol=*/1e-4);
3630 });
3631}
3632
3633TEST_F(LazyOpsTest, TestMatmulBcast) {
3634 torch::Tensor a = torch::rand(
3635 {4, 2, 3, 2, 4},
3636 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3637 torch::Tensor b = torch::rand(
3638 {2, 1, 4, 3},
3639 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3640 torch::Tensor c = torch::matmul(a, b);
3641 ForEachDevice([&](const torch::Device& device) {
3642 torch::Tensor lazy_a = CopyToDevice(a, device);
3643 torch::Tensor lazy_b = CopyToDevice(b, device);
3644 torch::Tensor lazy_c = torch::matmul(lazy_a, lazy_b);
3645 AllClose(c, lazy_c);
3646 });
3647}
3648
3649TEST_F(LazyOpsTest, TestDot) {
3650 torch::Tensor a = torch::rand(
3651 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3652 torch::Tensor b = torch::rand(
3653 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3654 torch::Tensor c = torch::dot(a, b);
3655 ForEachDevice([&](const torch::Device& device) {
3656 torch::Tensor lazy_a = CopyToDevice(a, device);
3657 torch::Tensor lazy_b = CopyToDevice(b, device);
3658 torch::Tensor lazy_c = torch::dot(lazy_a, lazy_b);
3659 AllClose(c, lazy_c);
3660 });
3661}
3662
3663TEST_F(LazyOpsTest, TestTensorDot) {
3664 torch::Tensor a = torch::rand(
3665 {6, 4, 8}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3666 torch::Tensor b = torch::rand(
3667 {4, 7, 8}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3668 std::vector<int64_t> dims_a = {1, 2};
3669 std::vector<int64_t> dims_b = {0, 2};
3670 torch::Tensor c = torch::tensordot(a, b, dims_a, dims_b);
3671 ForEachDevice([&](const torch::Device& device) {
3672 torch::Tensor lazy_a = CopyToDevice(a, device);
3673 torch::Tensor lazy_b = CopyToDevice(b, device);
3674 torch::Tensor lazy_c = torch::tensordot(lazy_a, lazy_b, dims_a, dims_b);
3675 AllClose(c, lazy_c);
3676 });
3677}
3678
3679TEST_F(LazyOpsTest, TestGer) {
3680 torch::Tensor a = torch::rand(
3681 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3682 torch::Tensor b = torch::rand(
3683 {5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3684 torch::Tensor c = torch::ger(a, b);
3685 ForEachDevice([&](const torch::Device& device) {
3686 torch::Tensor lazy_a = CopyToDevice(a, device);
3687 torch::Tensor lazy_b = CopyToDevice(b, device);
3688 torch::Tensor lazy_c = torch::ger(lazy_a, lazy_b);
3689 AllClose(c, lazy_c);
3690 });
3691}
3692
3693TEST_F(LazyOpsTest, TestMv) {
3694 torch::Tensor a = torch::rand(
3695 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3696 torch::Tensor b = torch::rand(
3697 {3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3698 torch::Tensor c = torch::mv(a, b);
3699 ForEachDevice([&](const torch::Device& device) {
3700 torch::Tensor lazy_a = CopyToDevice(a, device);
3701 torch::Tensor lazy_b = CopyToDevice(b, device);
3702 torch::Tensor lazy_c = torch::mv(lazy_a, lazy_b);
3703 AllClose(c, lazy_c);
3704 });
3705}
3706
3707TEST_F(LazyOpsTest, TestMvOut) {
3708 torch::Tensor a = torch::rand(
3709 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3710 torch::Tensor b = torch::rand(
3711 {3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3712 torch::Tensor c = torch::empty(
3713 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3714 torch::mv_out(c, a, b);
3715 ForEachDevice([&](const torch::Device& device) {
3716 torch::Tensor lazy_a = CopyToDevice(a, device);
3717 torch::Tensor lazy_b = CopyToDevice(b, device);
3718 torch::Tensor lazy_c = torch::empty({4}, lazy_b.options());
3719 torch::mv_out(lazy_c, lazy_a, lazy_b);
3720 AllClose(c, lazy_c);
3721 });
3722}
3723
3724TEST_F(LazyOpsTest, TestBatchAddBatchMatMul) {
3725 torch::Tensor a = torch::rand(
3726 {3, 6, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3727 torch::Tensor b = torch::rand(
3728 {3, 6, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3729 torch::Tensor c = torch::rand(
3730 {3, 4, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3731 torch::Scalar alpha = 0.5;
3732 torch::Scalar beta = 1.5;
3733 torch::Tensor d = torch::baddbmm(a, b, c, beta, alpha);
3734 ForEachDevice([&](const torch::Device& device) {
3735 torch::Tensor lazy_a = CopyToDevice(a, device);
3736 torch::Tensor lazy_b = CopyToDevice(b, device);
3737 torch::Tensor lazy_c = CopyToDevice(c, device);
3738 torch::Tensor lazy_d = torch::baddbmm(lazy_a, lazy_b, lazy_c, beta, alpha);
3739 AllClose(d, lazy_d, /*rtol=*/1e-3, /*atol=*/1e-4);
3740 });
3741}
3742
3743TEST_F(LazyOpsTest, TestBatchAddBatchMatMulInPlace) {
3744 torch::Tensor a = torch::rand(
3745 {3, 6, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3746 torch::Tensor b = torch::rand(
3747 {3, 6, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3748 torch::Tensor c = torch::rand(
3749 {3, 4, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3750 torch::Scalar alpha = 0.5;
3751 torch::Scalar beta = 1.5;
3752 ForEachDevice([&](const torch::Device& device) {
3753 torch::Tensor lazy_a = CopyToDevice(a, device);
3754 torch::Tensor lazy_b = CopyToDevice(b, device);
3755 torch::Tensor lazy_c = CopyToDevice(c, device);
3756 torch::Tensor d = a.baddbmm_(b, c, beta, alpha);
3757 torch::Tensor lazy_d = lazy_a.baddbmm_(lazy_b, lazy_c, beta, alpha);
3758 AllClose(d, lazy_d, /*rtol=*/1e-3, /*atol=*/1e-4);
3759 AllClose(a, lazy_a, /*rtol=*/1e-3, /*atol=*/1e-4);
3760 });
3761}
3762
3763TEST_F(LazyOpsTest, TestBatchMatMul) {
3764 torch::Tensor a = torch::rand(
3765 {3, 6, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3766 torch::Tensor b = torch::rand(
3767 {3, 4, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3768 torch::Tensor c = torch::bmm(a, b);
3769 ForEachDevice([&](const torch::Device& device) {
3770 torch::Tensor lazy_a = CopyToDevice(a, device);
3771 torch::Tensor lazy_b = CopyToDevice(b, device);
3772 torch::Tensor lazy_c = torch::bmm(lazy_a, lazy_b);
3773 AllClose(c, lazy_c, /*rtol=*/1e-3, /*atol=*/1e-4);
3774 });
3775}
3776
3777TEST_F(LazyOpsTest, TestChainMatMul) {
3778 torch::Tensor a = torch::rand(
3779 {5, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3780 torch::Tensor b = torch::rand(
3781 {4, 6}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3782 torch::Tensor c = torch::rand(
3783 {6, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3784 torch::Tensor d = torch::rand(
3785 {2, 7}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3786 torch::Tensor result = torch::chain_matmul({a, b, c, d});
3787 ForEachDevice([&](const torch::Device& device) {
3788 torch::Tensor lazy_a = CopyToDevice(a, device);
3789 torch::Tensor lazy_b = CopyToDevice(b, device);
3790 torch::Tensor lazy_c = CopyToDevice(c, device);
3791 torch::Tensor lazy_d = CopyToDevice(d, device);
3792 torch::Tensor lazy_result =
3793 torch::chain_matmul({lazy_a, lazy_b, lazy_c, lazy_d});
3794 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-4);
3795 });
3796}
3797
3798TEST_F(LazyOpsTest, TestLinear) {
3799 torch::Tensor input = torch::rand(
3800 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3801 torch::Tensor weight = torch::rand(
3802 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3803 torch::Tensor bias = torch::rand(
3804 {3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3805 torch::Tensor result = torch::linear(input, weight);
3806 torch::Tensor result_with_bias = torch::linear(input, weight, bias);
3807 ForEachDevice([&](const torch::Device& device) {
3808 torch::Tensor lazy_input = CopyToDevice(input, device);
3809 torch::Tensor lazy_weight = CopyToDevice(weight, device);
3810 torch::Tensor lazy_bias = CopyToDevice(bias, device);
3811 torch::Tensor lazy_result = torch::linear(lazy_input, lazy_weight);
3812 torch::Tensor lazy_result_with_bias =
3813 torch::linear(lazy_input, lazy_weight, lazy_bias);
3814 AllClose(result, lazy_result, /*rtol=*/1e-2, /*atol=*/1e-4);
3815 AllClose(
3816 result_with_bias,
3817 lazy_result_with_bias,
3818 /*rtol=*/1e-2,
3819 /*atol=*/1e-4);
3820 });
3821}
3822
3823TEST_F(LazyOpsTest, TestPinverse) {
3824 torch::Tensor input = torch::rand(
3825 {4, 6}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3826 torch::Tensor result = torch::pinverse(input);
3827 ForEachDevice([&](const torch::Device& device) {
3828 torch::Tensor lazy_input = CopyToDevice(input, device);
3829 torch::Tensor lazy_result = torch::pinverse(lazy_input);
3830 AllClose(result, lazy_result, /*rtol=*/1e-4);
3831 });
3832}
3833
3834TEST_F(LazyOpsTest, TestEinsumOuter) {
3835 torch::Tensor a = torch::rand(
3836 {5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3837 torch::Tensor b = torch::rand(
3838 {5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3839 std::string equation = "i,j->ij";
3840 torch::Tensor c = torch::einsum(equation, {a, b});
3841 ForEachDevice([&](const torch::Device& device) {
3842 torch::Tensor lazy_a = CopyToDevice(a, device);
3843 torch::Tensor lazy_b = CopyToDevice(b, device);
3844 torch::Tensor lazy_c = torch::einsum(equation, {lazy_a, lazy_b});
3845 AllClose(c, lazy_c);
3846 });
3847}
3848
3849TEST_F(LazyOpsTest, TestEinsumOuterBackward) {
3850 torch::Tensor a = torch::rand(
3851 {5},
3852 torch::TensorOptions(torch::kFloat)
3853 .device(DefaultDevice())
3854 .requires_grad(true));
3855 torch::Tensor b = torch::rand(
3856 {5},
3857 torch::TensorOptions(torch::kFloat)
3858 .device(DefaultDevice())
3859 .requires_grad(true));
3860 std::string equation = "i,j->ij";
3861 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
3862 return torch::einsum(equation, inputs);
3863 };
3864 ForEachDevice([&](const torch::Device& device) {
3865 TestBackward({a, b}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
3866 });
3867}
3868
3869TEST_F(LazyOpsTest, TestEinsumBatchMatMul) {
3870 torch::Tensor a = torch::rand(
3871 {3, 2, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3872 torch::Tensor b = torch::rand(
3873 {3, 5, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3874 std::string equation = "bij,bjk->bik";
3875 torch::Tensor c = torch::einsum(equation, {a, b});
3876 ForEachDevice([&](const torch::Device& device) {
3877 torch::Tensor lazy_a = CopyToDevice(a, device);
3878 torch::Tensor lazy_b = CopyToDevice(b, device);
3879 torch::Tensor lazy_c = torch::einsum(equation, {lazy_a, lazy_b});
3880 AllClose(c, lazy_c);
3881 });
3882}
3883
3884TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBilinear) {
3885 torch::Tensor a = torch::rand(
3886 {3, 5, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3887 torch::Tensor l = torch::rand(
3888 {2, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3889 torch::Tensor r = torch::rand(
3890 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3891 std::string equation = "bn,anm,bm->ba";
3892 torch::Tensor c = torch::einsum(equation, {l, a, r});
3893 ForEachDevice([&](const torch::Device& device) {
3894 torch::Tensor lazy_l = CopyToDevice(l, device);
3895 torch::Tensor lazy_a = CopyToDevice(a, device);
3896 torch::Tensor lazy_r = CopyToDevice(r, device);
3897 torch::Tensor lazy_c = torch::einsum(equation, {lazy_l, lazy_a, lazy_r});
3898 AllClose(c, lazy_c);
3899 });
3900}
3901
3902TEST_F(LazyOpsTest, TestEinsumPyTorchLowerDiagonal) {
3903 torch::Tensor input = torch::rand(
3904 {3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3905 std::string equation = "ii->i";
3906 torch::Tensor result = torch::einsum(equation, {input});
3907 ForEachDevice([&](const torch::Device& device) {
3908 torch::Tensor lazy_input = CopyToDevice(input, device);
3909 torch::Tensor lazy_result = torch::einsum(equation, {lazy_input});
3910 AllClose(result, lazy_result);
3911 });
3912}
3913
3914TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBatchDiagonal) {
3915 torch::Tensor input = torch::rand(
3916 {4, 3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3917 std::string equation = "...ii->...i";
3918 torch::Tensor result = torch::einsum(equation, {input});
3919 ForEachDevice([&](const torch::Device& device) {
3920 torch::Tensor lazy_input = CopyToDevice(input, device);
3921 torch::Tensor lazy_result = torch::einsum(equation, {lazy_input});
3922 AllClose(result, lazy_result);
3923 });
3924}
3925
3926TEST_F(LazyOpsTest, TestEinsumPyTorchLowerBatchPermute) {
3927 torch::Tensor input = torch::rand(
3928 {2, 3, 4, 5},
3929 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3930 std::string equation = "...ij->...ji";
3931 torch::Tensor result = torch::einsum(equation, {input});
3932 ForEachDevice([&](const torch::Device& device) {
3933 torch::Tensor lazy_input = CopyToDevice(input, device);
3934 torch::Tensor lazy_result = torch::einsum(equation, {lazy_input});
3935 AllClose(result, lazy_result);
3936 });
3937}
3938
3939TEST_F(LazyOpsTest, TestEinsumPyTorchLowerRepeatedAxis) {
3940 torch::Tensor x = torch::rand(
3941 {2, 3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3942 torch::Tensor y = torch::rand(
3943 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3944 std::string equation = "ijj,k->ik";
3945 torch::Tensor result = torch::einsum(equation, {x, y});
3946 ForEachDevice([&](const torch::Device& device) {
3947 torch::Tensor lazy_x = CopyToDevice(x, device);
3948 torch::Tensor lazy_y = CopyToDevice(y, device);
3949 torch::Tensor lazy_result = torch::einsum(equation, {lazy_x, lazy_y});
3950 AllClose(result, lazy_result);
3951 });
3952}
3953
3954TEST_F(LazyOpsTest, TestBilinear) {
3955 int batch_size = 16;
3956 int in1_features = 4;
3957 int in2_features = 6;
3958 int out_features = 8;
3959 torch::Tensor input1 = torch::rand(
3960 {batch_size, in1_features},
3961 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3962 torch::Tensor input2 = torch::rand(
3963 {batch_size, in2_features},
3964 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3965 torch::Tensor weight = torch::rand(
3966 {out_features, in1_features, in2_features},
3967 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3968 torch::Tensor bias = torch::rand(
3969 {out_features},
3970 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3971 ForEachDevice([&](const torch::Device& device) {
3972 torch::Tensor lazy_input1 = CopyToDevice(input1, device);
3973 torch::Tensor lazy_input2 = CopyToDevice(input2, device);
3974 torch::Tensor lazy_weight = CopyToDevice(weight, device);
3975 torch::Tensor lazy_bias = CopyToDevice(bias, device);
3976 torch::Tensor result = torch::bilinear(input1, input2, weight, bias);
3977 torch::Tensor lazy_result =
3978 torch::bilinear(lazy_input1, lazy_input2, lazy_weight, lazy_bias);
3979 AllClose(result, lazy_result);
3980 });
3981}
3982
3983TEST_F(LazyOpsTest, TestUpsampleNearest2D) {
3984 int batch_size = 2;
3985 int h = 5;
3986 int w = 5;
3987 int uh = 8;
3988 int uw = 8;
3989 int chans = 2;
3990 torch::Tensor input = torch::rand(
3991 {batch_size, chans, h, w},
3992 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
3993 ForEachDevice([&](const torch::Device& device) {
3994 torch::Tensor lazy_input = CopyToDevice(input, device);
3995 torch::Tensor result = torch::upsample_nearest2d(input, {uh, uw});
3996 torch::Tensor lazy_result = torch::upsample_nearest2d(lazy_input, {uh, uw});
3997 AllClose(result, lazy_result);
3998 });
3999}
4000
4001TEST_F(LazyOpsTest, TestUpsampleNearest2DBackward) {
4002 int batch_size = 2;
4003 int h = 5;
4004 int w = 5;
4005 int uh = 8;
4006 int uw = 8;
4007 int chans = 2;
4008 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
4009 return torch::upsample_nearest2d(inputs[0], {uh, uw});
4010 };
4011 ForEachDevice([&](const torch::Device& device) {
4012 TestBackward(
4013 {torch::rand(
4014 {batch_size, chans, h, w},
4015 torch::TensorOptions(torch::kFloat)
4016 .device(DefaultDevice())
4017 .requires_grad(true))},
4018 device,
4019 testfn);
4020 });
4021}
4022
4023TEST_F(LazyOpsTest, TestUpsampleNearest2DWithScale) {
4024 int batch_size = 2;
4025 int h = 5;
4026 int w = 5;
4027 int chans = 2;
4028 double scale_h = 2.5;
4029 double scale_w = 3.4;
4030 torch::Tensor input = torch::rand(
4031 {batch_size, chans, h, w},
4032 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4033 ForEachDevice([&](const torch::Device& device) {
4034 torch::Tensor lazy_input = CopyToDevice(input, device);
4035 torch::Tensor result = torch::upsample_nearest2d(
4036 input, c10::nullopt, at::ArrayRef<double>{scale_h, scale_w});
4037 torch::Tensor lazy_result = torch::upsample_nearest2d(
4038 lazy_input, c10::nullopt, at::ArrayRef<double>{scale_h, scale_w});
4039 AllClose(result, lazy_result);
4040 });
4041}
4042
4043TEST_F(LazyOpsTest, TestUpsampleNearest2DBackwardWithScale) {
4044 int batch_size = 2;
4045 int h = 5;
4046 int w = 5;
4047 int chans = 2;
4048 double scale_h = 2.5;
4049 double scale_w = 3.4;
4050 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
4051 return torch::upsample_nearest2d(
4052 inputs[0], c10::nullopt, at::ArrayRef<double>{scale_h, scale_w});
4053 };
4054 ForEachDevice([&](const torch::Device& device) {
4055 TestBackward(
4056 {torch::rand(
4057 {batch_size, chans, h, w},
4058 torch::TensorOptions(torch::kFloat)
4059 .device(DefaultDevice())
4060 .requires_grad(true))},
4061 device,
4062 testfn);
4063 });
4064}
4065
4066TEST_F(LazyOpsTest, TestUpsampleBilinear2D) {
4067 int batch_size = 2;
4068 int h = 5;
4069 int w = 5;
4070 int uh = 8;
4071 int uw = 8;
4072 int chans = 2;
4073 for (bool align_corners : {true, false}) {
4074 torch::Tensor input = torch::rand(
4075 {batch_size, chans, h, w},
4076 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4077 ForEachDevice([&](const torch::Device& device) {
4078 torch::Tensor lazy_input = CopyToDevice(input, device);
4079 torch::Tensor result =
4080 torch::upsample_bilinear2d(input, {uh, uw}, align_corners);
4081 torch::Tensor lazy_result =
4082 torch::upsample_bilinear2d(lazy_input, {uh, uw}, align_corners);
4083 AllClose(result, lazy_result);
4084 });
4085 }
4086}
4087
4088TEST_F(LazyOpsTest, TestUpsampleBilinear2DBackward) {
4089 int batch_size = 2;
4090 int h = 5;
4091 int w = 5;
4092 int uh = 8;
4093 int uw = 8;
4094 int chans = 2;
4095 for (bool align_corners : {true, false}) {
4096 auto testfn =
4097 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
4098 return torch::upsample_bilinear2d(inputs[0], {uh, uw}, align_corners);
4099 };
4100 ForEachDevice([&](const torch::Device& device) {
4101 TestBackward(
4102 {torch::rand(
4103 {batch_size, chans, h, w},
4104 torch::TensorOptions(torch::kFloat)
4105 .device(DefaultDevice())
4106 .requires_grad(true))},
4107 device,
4108 testfn);
4109 });
4110 }
4111}
4112
4113TEST_F(LazyOpsTest, TestAddCMul) {
4114 torch::Tensor a = torch::rand(
4115 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4116 torch::Tensor b = torch::rand(
4117 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4118 torch::Tensor c = torch::rand(
4119 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4120 torch::Tensor d = torch::addcmul(a, b, c, 3.1165);
4121 ForEachDevice([&](const torch::Device& device) {
4122 torch::Tensor lazy_a = CopyToDevice(a, device);
4123 torch::Tensor lazy_b = CopyToDevice(b, device);
4124 torch::Tensor lazy_c = CopyToDevice(c, device);
4125 torch::Tensor lazy_d = torch::addcmul(lazy_a, lazy_b, lazy_c, 3.1165);
4126 AllClose(d, lazy_d);
4127 });
4128}
4129
4130TEST_F(LazyOpsTest, TestAddCDiv) {
4131 torch::Tensor a = torch::rand(
4132 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4133 torch::Tensor b = torch::rand(
4134 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4135 torch::Tensor c =
4136 torch::abs(torch::rand(
4137 {2, 2},
4138 torch::TensorOptions(torch::kFloat).device(DefaultDevice()))) +
4139 1.0;
4140 torch::Tensor d = torch::addcdiv(a, b, c, 3.1165);
4141 ForEachDevice([&](const torch::Device& device) {
4142 torch::Tensor lazy_a = CopyToDevice(a, device);
4143 torch::Tensor lazy_b = CopyToDevice(b, device);
4144 torch::Tensor lazy_c = CopyToDevice(c, device);
4145 torch::Tensor lazy_d = torch::addcdiv(lazy_a, lazy_b, lazy_c, 3.1165);
4146 AllClose(d, lazy_d);
4147 });
4148}
4149
4150TEST_F(LazyOpsTest, TestAddCDivWithBroadcast) {
4151 torch::Tensor a = torch::rand(
4152 {1, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4153 torch::Tensor b = torch::rand(
4154 {3, 1}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4155 torch::Tensor c =
4156 torch::abs(torch::rand(
4157 {1, 3},
4158 torch::TensorOptions(torch::kFloat).device(DefaultDevice()))) +
4159 1.0;
4160 torch::Tensor d = torch::addcdiv(a, b, c, 3.1165);
4161 ForEachDevice([&](const torch::Device& device) {
4162 torch::Tensor lazy_a = CopyToDevice(a, device);
4163 torch::Tensor lazy_b = CopyToDevice(b, device);
4164 torch::Tensor lazy_c = CopyToDevice(c, device);
4165 torch::Tensor lazy_d = torch::addcdiv(lazy_a, lazy_b, lazy_c, 3.1165);
4166 AllClose(d, lazy_d);
4167 });
4168}
4169
4170TEST_F(LazyOpsTest, TestSize) {
4171 torch::Tensor input = torch::rand(
4172 {2, 1, 4, 6},
4173 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4174 int rank = input.dim();
4175 ForEachDevice([&](const torch::Device& device) {
4176 torch::Tensor lazy_input = CopyToDevice(input, device);
4177 for (int dim = -rank; dim < rank; ++dim) {
4178 EXPECT_EQ(torch::size(input, dim), torch::size(lazy_input, dim));
4179 }
4180 });
4181}
4182
4183TEST_F(LazyOpsTest, TestSelect) {
4184 std::vector<int64_t> input_sizes = {14, 24, 8};
4185 int rank = input_sizes.size();
4186 for (int dim = -rank; dim < rank; ++dim) {
4187 auto testfn =
4188 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
4189 return torch::select(inputs[0], dim, 0);
4190 };
4191 ForEachDevice([&](const torch::Device& device) {
4192 TestBackward(
4193 {torch::rand(
4194 input_sizes,
4195 torch::TensorOptions(torch::kFloat).requires_grad(true))},
4196 device,
4197 testfn);
4198 });
4199 };
4200}
4201
4202TEST_F(LazyOpsTest, TestBernoulliScalarProb) {
4203 torch::Tensor input = torch::zeros(
4204 1000, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4205 ForEachDevice([&](const torch::Device& device) {
4206 torch::Tensor lazy_input = CopyToDevice(input, device);
4207 torch::Tensor lazy_output = torch::bernoulli(lazy_input, 0.1);
4208 double frac = lazy_output.sum().item().toDouble() / input.numel();
4209 EXPECT_GT(frac, 0.06);
4210 EXPECT_LT(frac, 0.14);
4211 });
4212}
4213
4214TEST_F(LazyOpsTest, TestBernoulliTensorProb) {
4215 std::vector<float> prob_values(1000, 0.1);
4216 torch::Tensor input = torch::tensor(
4217 prob_values, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4218 ForEachDevice([&](const torch::Device& device) {
4219 torch::Tensor lazy_input = CopyToDevice(input, device);
4220 torch::Tensor lazy_output = torch::bernoulli(lazy_input);
4221 double frac = lazy_output.sum().item().toDouble() / input.numel();
4222 EXPECT_GT(frac, 0.06);
4223 EXPECT_LT(frac, 0.14);
4224 });
4225}
4226
4227TEST_F(LazyOpsTest, TestBernoulliScalarProbInPlace) {
4228 torch::Tensor input = torch::zeros(
4229 1000, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4230 ForEachDevice([&](const torch::Device& device) {
4231 torch::Tensor lazy_input = CopyToDevice(input, device);
4232 lazy_input.bernoulli_(0.1);
4233 double frac = lazy_input.sum().item().toDouble() / input.numel();
4234 EXPECT_GT(frac, 0.06);
4235 EXPECT_LT(frac, 0.14);
4236 });
4237}
4238
4239TEST_F(LazyOpsTest, TestBernoulliTensorProbInPlace) {
4240 torch::Tensor input = torch::zeros(
4241 1000, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4242 torch::Tensor prob = torch::scalar_tensor(
4243 0.1, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4244 ForEachDevice([&](const torch::Device& device) {
4245 torch::Tensor lazy_input = CopyToDevice(input, device);
4246 torch::Tensor lazy_prob = CopyToDevice(prob, device);
4247 lazy_input.bernoulli_(lazy_prob);
4248 double frac = lazy_input.sum().item().toDouble() / input.numel();
4249 EXPECT_GT(frac, 0.06);
4250 EXPECT_LT(frac, 0.14);
4251 });
4252}
4253
4254TEST_F(LazyOpsTest, TestDropout) {
4255 torch::Tensor a = torch::rand(
4256 {17, 21}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4257 ForEachDevice([&](const torch::Device& device) {
4258 torch::Tensor lazy_a = CopyToDevice(a, device);
4259 torch::Tensor lazy_b = torch::dropout(lazy_a, 0.1, /*train=*/true);
4260 double prob =
4261 static_cast<double>(lazy_b.cpu().ne(0.0f).sum().item().toDouble()) /
4262 a.numel();
4263 EXPECT_GT(prob, 0.86);
4264 EXPECT_LT(prob, 0.94);
4265 });
4266}
4267
4268TEST_F(LazyOpsTest, TestDropoutInPlace) {
4269 torch::Tensor a = torch::rand(
4270 {17, 21}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4271 ForEachDevice([&](const torch::Device& device) {
4272 torch::Tensor lazy_a = CopyToDevice(a, device);
4273 torch::dropout_(lazy_a, 0.1, /*train=*/true);
4274 double prob =
4275 static_cast<double>(lazy_a.cpu().ne(0.0f).sum().item().toDouble()) /
4276 a.numel();
4277 EXPECT_GT(prob, 0.85);
4278 EXPECT_LT(prob, 0.94);
4279 });
4280}
4281
4282TEST_F(LazyOpsTest, TestRandperm) {
4283 unsigned n = 5;
4284 torch::Tensor shuffle = torch::randperm(
4285 n, torch::TensorOptions(torch::kLong).device(torch::kLazy));
4286 torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
4287 std::vector<int64_t> shuffle_data(
4288 shuffle_cpu.data_ptr<int64_t>(), shuffle_cpu.data_ptr<int64_t>() + n);
4289 EXPECT_TRUE(
4290 shuffle_data.size() == n && torch::lazy::IsPermutation(shuffle_data));
4291}
4292
4293TEST_F(LazyOpsTest, TestSlice) {
4294 torch::Tensor a = torch::rand(
4295 {32, 24, 16},
4296 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4297 torch::Tensor b = torch::slice(a, 1, 0, 16, 1);
4298 ForEachDevice([&](const torch::Device& device) {
4299 torch::Tensor lazy_a = CopyToDevice(a, device);
4300 torch::Tensor lazy_b = torch::slice(lazy_a, 1, 0, 16, 1);
4301 AllClose(b, lazy_b);
4302 });
4303}
4304
4305TEST_F(LazyOpsTest, TestTake) {
4306 torch::Tensor a = torch::rand(
4307 {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4308 torch::Tensor b = torch::randint(
4309 16, {5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4310 torch::Tensor c = torch::take(a, b);
4311 ForEachDevice([&](const torch::Device& device) {
4312 torch::Tensor lazy_a = CopyToDevice(a, device);
4313 torch::Tensor lazy_b = CopyToDevice(b, device);
4314 torch::Tensor lazy_c = torch::take(lazy_a, lazy_b);
4315 AllClose(c, lazy_c);
4316 });
4317}
4318
4319TEST_F(LazyOpsTest, TestTakeBackward) {
4320 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
4321 return torch::take(inputs[0], inputs[1]);
4322 };
4323 ForEachDevice([&](const torch::Device& device) {
4324 TestBackward(
4325 {torch::rand(
4326 {4, 4},
4327 torch::TensorOptions(torch::kFloat)
4328 .device(DefaultDevice())
4329 .requires_grad(true)),
4330 torch::randint(
4331 16,
4332 {5},
4333 torch::TensorOptions(torch::kLong).device(DefaultDevice()))},
4334 device,
4335 testfn);
4336 });
4337}
4338
4339TEST_F(LazyOpsTest, TestStack) {
4340 torch::Tensor a = torch::rand(
4341 {2, 4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4342 torch::Tensor b = torch::rand(
4343 {2, 4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4344 torch::Tensor c = torch::rand(
4345 {2, 4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4346 int rank = a.dim() + 1;
4347 for (int dim = -rank; dim < rank; ++dim) {
4348 torch::Tensor d = torch::stack({a, b, c}, dim);
4349 ForEachDevice([&](const torch::Device& device) {
4350 torch::Tensor lazy_a = CopyToDevice(a, device);
4351 torch::Tensor lazy_b = CopyToDevice(b, device);
4352 torch::Tensor lazy_c = CopyToDevice(c, device);
4353 torch::Tensor lazy_d = torch::stack({lazy_a, lazy_b, lazy_c}, dim);
4354 AllClose(d, lazy_d);
4355 });
4356 }
4357}
4358
4359TEST_F(LazyOpsTest, TestCat) {
4360 torch::Tensor a = torch::rand(
4361 {2, 1, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4362 torch::Tensor b = torch::rand(
4363 {2, 2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4364 torch::Tensor c = torch::rand(
4365 {2, 3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4366 for (int dim : {1, -2}) {
4367 torch::Tensor d = torch::cat({a, b, c}, dim);
4368 ForEachDevice([&](const torch::Device& device) {
4369 torch::Tensor lazy_a = CopyToDevice(a, device);
4370 torch::Tensor lazy_b = CopyToDevice(b, device);
4371 torch::Tensor lazy_c = CopyToDevice(c, device);
4372 torch::Tensor lazy_d = torch::cat({lazy_a, lazy_b, lazy_c}, dim);
4373 EXPECT_TRUE(d.sizes() == lazy_d.sizes() && d.dtype() == lazy_d.dtype());
4374 AllClose(d, lazy_d);
4375 });
4376 }
4377}
4378
4379TEST_F(LazyOpsTest, TestUnbind) {
4380 torch::Tensor input = torch::rand(
4381 {4, 3, 7}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4382 int rank = input.dim();
4383 for (int dim = -rank; dim < rank; ++dim) {
4384 std::vector<torch::Tensor> output = torch::unbind(input, dim);
4385 ForEachDevice([&](const torch::Device& device) {
4386 torch::Tensor lazy_input = CopyToDevice(input, device);
4387 std::vector<torch::Tensor> lazy_output = torch::unbind(lazy_input, dim);
4388 ASSERT_EQ(output.size(), lazy_output.size());
4389 for (size_t i = 0; i < output.size(); ++i) {
4390 AllClose(output[i], lazy_output[i]);
4391 }
4392 });
4393 }
4394}
4395
4396TEST_F(LazyOpsTest, TestRepeat) {
4397 std::vector<std::vector<int64_t>> repeats_list = {{4, 2}, {4, 2, 3}};
4398 std::vector<std::vector<int64_t>> input_size_list = {{3}, {2, 4}};
4399 for (const auto& repeats : repeats_list) {
4400 for (const auto& input_size : input_size_list) {
4401 torch::Tensor input = torch::rand(
4402 input_size,
4403 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4404 torch::Tensor output = input.repeat(repeats);
4405 ForEachDevice([&](const torch::Device& device) {
4406 torch::Tensor lazy_input = CopyToDevice(input, device);
4407 torch::Tensor lazy_output = lazy_input.repeat(repeats);
4408 AllClose(output, lazy_output);
4409 });
4410 }
4411 }
4412}
4413
4414TEST_F(LazyOpsTest, TestGather) {
4415 torch::Tensor a = torch::rand(
4416 {3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4417 torch::Tensor b = torch::empty(
4418 {3, 3}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4419 for (int i = 0; i < 3; i++) {
4420 for (int j = 0; j < 3; j++) {
4421 b[i][j] = (i + j) % 3;
4422 }
4423 }
4424 for (bool sparse_grad : {false, true}) {
4425 torch::Tensor c = torch::gather(a, 1, b, sparse_grad);
4426 ForEachDevice([&](const torch::Device& device) {
4427 torch::Tensor lazy_a = CopyToDevice(a, device);
4428 torch::Tensor lazy_b = CopyToDevice(b, device);
4429 torch::Tensor lazy_c = torch::gather(lazy_a, 1, lazy_b, sparse_grad);
4430 AllClose(c, lazy_c);
4431 });
4432 }
4433}
4434
4435TEST_F(LazyOpsTest, TestScatter) {
4436 torch::Tensor a = torch::rand(
4437 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4438 torch::Tensor b = torch::rand(
4439 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4440 torch::Tensor c = torch::empty(
4441 {3, 5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4442 for (int dim = 0; dim < 2; ++dim) {
4443 for (int i = 0; i < 3; i++) {
4444 for (int j = 0; j < 5; j++) {
4445 c[i][j] = (i + j) % c.sizes()[dim];
4446 }
4447 }
4448 torch::Tensor d = torch::scatter(a, dim, c, b);
4449 ForEachDevice([&](const torch::Device& device) {
4450 torch::Tensor lazy_a = CopyToDevice(a, device);
4451 torch::Tensor lazy_b = CopyToDevice(b, device);
4452 torch::Tensor lazy_c = CopyToDevice(c, device);
4453 torch::Tensor lazy_d = torch::scatter(lazy_a, dim, lazy_c, lazy_b);
4454 AllClose(d, lazy_d);
4455 });
4456 }
4457}
4458
4459TEST_F(LazyOpsTest, TestScatterR1) {
4460 torch::Tensor a = torch::rand(
4461 {5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4462 torch::Tensor b = torch::rand(
4463 {2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4464 torch::Tensor c = torch::empty(
4465 {2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4466 c[0] = 1;
4467 c[1] = 3;
4468 torch::Tensor d = torch::scatter(a, 0, c, b);
4469 ForEachDevice([&](const torch::Device& device) {
4470 torch::Tensor lazy_a = CopyToDevice(a, device);
4471 torch::Tensor lazy_b = CopyToDevice(b, device);
4472 torch::Tensor lazy_c = CopyToDevice(c, device);
4473 torch::Tensor lazy_d = torch::scatter(lazy_a, 0, lazy_c, lazy_b);
4474 AllClose(d, lazy_d);
4475 });
4476}
4477
4478TEST_F(LazyOpsTest, TestScatterR3) {
4479 torch::Tensor a = torch::rand(
4480 {3, 5, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4481 torch::Tensor b = torch::rand(
4482 {3, 4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4483 torch::Tensor c = torch::empty(
4484 {3, 4, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4485 for (int i = 0; i < 3; i++) {
4486 for (int j = 0; j < 4; j++) {
4487 for (int k = 0; k < 2; k++) {
4488 c[i][j][k] = (i + j + k) % 4;
4489 }
4490 }
4491 }
4492 torch::Tensor d = torch::scatter(a, 1, c, b);
4493 ForEachDevice([&](const torch::Device& device) {
4494 torch::Tensor lazy_a = CopyToDevice(a, device);
4495 torch::Tensor lazy_b = CopyToDevice(b, device);
4496 torch::Tensor lazy_c = CopyToDevice(c, device);
4497 torch::Tensor lazy_d = torch::scatter(lazy_a, 1, lazy_c, lazy_b);
4498 AllClose(d, lazy_d);
4499 });
4500}
4501
4502TEST_F(LazyOpsTest, TestScatterBiggerSource) {
4503 torch::Tensor a = torch::rand(
4504 {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4505 torch::Tensor b = torch::rand(
4506 {8, 8}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4507 torch::Tensor c = torch::empty(
4508 {4, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4509 for (int i = 0; i < 4; i++) {
4510 for (int j = 0; j < 4; j++) {
4511 c[i][j] = (i + j) % 4;
4512 }
4513 }
4514 for (int dim = 0; dim < 2; ++dim) {
4515 torch::Tensor d = torch::scatter(a, dim, c, b);
4516 ForEachDevice([&](const torch::Device& device) {
4517 torch::Tensor lazy_a = CopyToDevice(a, device);
4518 torch::Tensor lazy_b = CopyToDevice(b, device);
4519 torch::Tensor lazy_c = CopyToDevice(c, device);
4520 torch::Tensor lazy_d = torch::scatter(lazy_a, dim, lazy_c, lazy_b);
4521 AllClose(d, lazy_d);
4522 });
4523 }
4524}
4525
4526TEST_F(LazyOpsTest, TestScatterScalar) {
4527 torch::Tensor a = torch::rand(
4528 {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4529 torch::Scalar b = 1.0f;
4530 torch::Tensor c = torch::empty(
4531 {4, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4532 for (int i = 0; i < 4; i++) {
4533 for (int j = 0; j < 4; j++) {
4534 c[i][j] = (i + j) % 4;
4535 }
4536 }
4537 for (int dim = 0; dim < 2; ++dim) {
4538 torch::Tensor d = torch::scatter(a, dim, c, b);
4539 ForEachDevice([&](const torch::Device& device) {
4540 torch::Tensor lazy_a = CopyToDevice(a, device);
4541 torch::Tensor lazy_c = CopyToDevice(c, device);
4542 torch::Tensor lazy_d = torch::scatter(lazy_a, dim, lazy_c, b);
4543 AllClose(d, lazy_d);
4544 });
4545 }
4546}
4547
4548TEST_F(LazyOpsTest, TestScatterReduceAdd) {
4549 torch::Tensor a = torch::rand(
4550 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4551 torch::Tensor b = torch::rand(
4552 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4553 torch::Tensor c = torch::empty(
4554 {3, 5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4555 for (int dim = 0; dim < 2; ++dim) {
4556 for (int i = 0; i < 3; i++) {
4557 for (int j = 0; j < 5; j++) {
4558 c[i][j] = (i + j) % c.sizes()[dim];
4559 }
4560 }
4561 torch::Tensor d = torch::scatter(a, dim, c, b, "add");
4562 ForEachDevice([&](const torch::Device& device) {
4563 torch::Tensor lazy_a = CopyToDevice(a, device);
4564 torch::Tensor lazy_b = CopyToDevice(b, device);
4565 torch::Tensor lazy_c = CopyToDevice(c, device);
4566 torch::Tensor lazy_d = torch::scatter(lazy_a, dim, lazy_c, lazy_b, "add");
4567 AllClose(d, lazy_d);
4568 });
4569 }
4570
4571 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
4572 ExpectCounterChanged("lazy::scatter_out", GetIgnoredCounters());
4573}
4574
4575TEST_F(LazyOpsTest, TestScatterAdd) {
4576 torch::Tensor a = torch::rand(
4577 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4578 torch::Tensor b = torch::rand(
4579 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4580 torch::Tensor c = torch::empty(
4581 {3, 5}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4582 for (int dim = 0; dim < 2; ++dim) {
4583 for (int i = 0; i < 3; i++) {
4584 for (int j = 0; j < 5; j++) {
4585 c[i][j] = (i + j) % c.sizes()[dim];
4586 }
4587 }
4588 torch::Tensor d = torch::scatter_add(a, dim, c, b);
4589 ForEachDevice([&](const torch::Device& device) {
4590 torch::Tensor lazy_a = CopyToDevice(a, device);
4591 torch::Tensor lazy_b = CopyToDevice(b, device);
4592 torch::Tensor lazy_c = CopyToDevice(c, device);
4593 torch::Tensor lazy_d = torch::scatter_add(lazy_a, dim, lazy_c, lazy_b);
4594 AllClose(d, lazy_d);
4595 });
4596 }
4597}
4598
4599TEST_F(LazyOpsTest, TestScatterAddInPlace) {
4600 torch::Tensor b = torch::rand(
4601 {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4602 torch::Tensor c = torch::empty(
4603 {4, 4}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4604 for (int i = 0; i < 4; i++) {
4605 for (int j = 0; j < 4; j++) {
4606 c[i][j] = (i + j) % 4;
4607 }
4608 }
4609 for (int dim = 0; dim < 2; ++dim) {
4610 ForEachDevice([&](const torch::Device& device) {
4611 torch::Tensor a = torch::rand(
4612 {4, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4613 torch::Tensor lazy_a = CopyToDevice(a, device);
4614 torch::Tensor d = a.scatter_add_(dim, c, b);
4615 torch::Tensor lazy_b = CopyToDevice(b, device);
4616 torch::Tensor lazy_c = CopyToDevice(c, device);
4617 torch::Tensor lazy_d = lazy_a.scatter_add_(dim, lazy_c, lazy_b);
4618 AllClose(d, lazy_d);
4619 AllClose(a, lazy_a);
4620 });
4621 }
4622}
4623
4624TEST_F(LazyOpsTest, TestIndexSelect) {
4625 for (torch::ScalarType scalar_type :
4626 {torch::kFloat,
4627 torch::kByte,
4628 torch::kChar,
4629 torch::kShort,
4630 torch::kInt,
4631 torch::kLong}) {
4632 torch::Tensor a = isFloatingType(scalar_type)
4633 ? torch::rand(
4634 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
4635 : torch::randint(
4636 100,
4637 {3, 4},
4638 torch::TensorOptions(scalar_type).device(DefaultDevice()));
4639 for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) {
4640 torch::Tensor b = torch::empty(
4641 {2}, torch::TensorOptions(index_scalar_type).device(DefaultDevice()));
4642 b[0] = 0;
4643 b[1] = 2;
4644 for (auto offset : {-2, 0}) {
4645 torch::Tensor c0 = torch::index_select(a, 0 + offset, b);
4646 torch::Tensor c1 = torch::index_select(a, 1 + offset, b);
4647 ForEachDevice([&](const torch::Device& device) {
4648 torch::Tensor lazy_a = CopyToDevice(a, device);
4649 torch::Tensor lazy_b = CopyToDevice(b, device);
4650 torch::Tensor lazy_c0 =
4651 torch::index_select(lazy_a, 0 + offset, lazy_b);
4652 torch::Tensor lazy_c1 =
4653 torch::index_select(lazy_a, 1 + offset, lazy_b);
4654 AllEqual(c0, lazy_c0);
4655 AllEqual(c1, lazy_c1);
4656 });
4657 }
4658 }
4659 }
4660}
4661
4662TEST_F(LazyOpsTest, TestIndexSelectRank0) {
4663 for (torch::ScalarType scalar_type :
4664 {torch::kFloat,
4665 torch::kByte,
4666 torch::kChar,
4667 torch::kShort,
4668 torch::kInt,
4669 torch::kLong}) {
4670 torch::Tensor a = isFloatingType(scalar_type)
4671 ? torch::rand(
4672 {3, 4}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
4673 : torch::randint(
4674 100,
4675 {3, 4},
4676 torch::TensorOptions(scalar_type).device(DefaultDevice()));
4677 torch::Tensor b = torch::scalar_tensor(
4678 2, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4679 torch::Tensor c0 = torch::index_select(a, 0, b);
4680 torch::Tensor c1 = torch::index_select(a, 1, b);
4681 ForEachDevice([&](const torch::Device& device) {
4682 torch::Tensor lazy_a = CopyToDevice(a, device);
4683 torch::Tensor lazy_b = CopyToDevice(b, device);
4684 torch::Tensor lazy_c0 = torch::index_select(lazy_a, 0, lazy_b);
4685 torch::Tensor lazy_c1 = torch::index_select(lazy_a, 1, lazy_b);
4686 AllEqual(c0, lazy_c0);
4687 AllEqual(c1, lazy_c1);
4688 });
4689 }
4690}
4691
4692TEST_F(LazyOpsTest, TestInverse) {
4693 if (IsCuda()) {
4694 // TODO(whc) debug failure on cuda, lazy_b comes back transposed
4695 GTEST_SKIP();
4696 }
4697 torch::Tensor a = torch::randn(
4698 {5, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4699 torch::Tensor b = torch::inverse(a);
4700 ForEachDevice([&](const torch::Device& device) {
4701 torch::Tensor lazy_a = CopyToDevice(a, device);
4702 torch::Tensor lazy_b = torch::inverse(lazy_a);
4703 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-4);
4704 });
4705}
4706
4707TEST_F(LazyOpsTest, TestIsnan) {
4708 torch::Tensor a = torch::tensor(
4709 {1.0, 2.0, std::nan("1"), 4.0},
4710 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4711 torch::Tensor b = torch::isnan(a);
4712 ForEachDevice([&](const torch::Device& device) {
4713 torch::Tensor lazy_a = CopyToDevice(a, device);
4714 torch::Tensor lazy_b = torch::isnan(lazy_a);
4715 AllEqual(b, lazy_b);
4716 });
4717 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
4718 ExpectCounterChanged("lazy::isnan", GetIgnoredCounters());
4719}
4720
4721TEST_F(LazyOpsTest, TestExpand) {
4722 torch::Tensor a = torch::rand(
4723 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4724 torch::Tensor b = a.expand({2, 3, 4}, /*implicit=*/false);
4725 ForEachDevice([&](const torch::Device& device) {
4726 torch::Tensor lazy_a = CopyToDevice(a, device);
4727 torch::Tensor lazy_b = lazy_a.expand({2, 3, 4}, /*implicit=*/false);
4728 AllClose(b, lazy_b);
4729 });
4730}
4731
4732TEST_F(LazyOpsTest, TestExpandBack) {
4733 torch::Tensor a = torch::rand(
4734 {3, 1}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4735 torch::Tensor b = a.expand({3, 4}, /*implicit=*/false);
4736 ForEachDevice([&](const torch::Device& device) {
4737 torch::Tensor lazy_a = CopyToDevice(a, device);
4738 torch::Tensor lazy_b = lazy_a.expand({3, 4}, /*implicit=*/false);
4739 AllClose(b, lazy_b);
4740 });
4741}
4742
4743TEST_F(LazyOpsTest, TestExpandAs) {
4744 torch::Tensor a = torch::rand(
4745 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4746 torch::Tensor b = torch::rand(
4747 {2, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4748 torch::Tensor c = torch::native::expand_as(a, b);
4749 ForEachDevice([&](const torch::Device& device) {
4750 torch::Tensor lazy_a = CopyToDevice(a, device);
4751 torch::Tensor lazy_b = CopyToDevice(b, device);
4752 torch::Tensor lazy_c = torch::native::expand_as(lazy_a, lazy_b);
4753 AllClose(c, lazy_c);
4754 });
4755}
4756
4757TEST_F(LazyOpsTest, TestEye) {
4758 int n = 5;
4759 ForEachDevice([&](const torch::Device& device) {
4760 torch::Tensor out = torch::eye(
4761 n, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4762 torch::Tensor lazy_out =
4763 torch::eye(n, torch::TensorOptions(torch::kFloat).device(device));
4764 AllClose(out, lazy_out);
4765 });
4766}
4767
4768TEST_F(LazyOpsTest, TestEyeWide) {
4769 int lines = 3;
4770 int cols = 5;
4771 ForEachDevice([&](const torch::Device& device) {
4772 torch::Tensor out = torch::eye(
4773 lines,
4774 cols,
4775 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4776 torch::Tensor lazy_out = torch::eye(
4777 lines, cols, torch::TensorOptions(torch::kFloat).device(device));
4778 AllClose(out, lazy_out);
4779 });
4780}
4781
4782TEST_F(LazyOpsTest, TestEyeNarrow) {
4783 int lines = 5;
4784 int cols = 3;
4785 ForEachDevice([&](const torch::Device& device) {
4786 torch::Tensor out = torch::eye(
4787 lines,
4788 cols,
4789 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4790 torch::Tensor lazy_out = torch::eye(
4791 lines, cols, torch::TensorOptions(torch::kFloat).device(device));
4792 AllClose(out, lazy_out);
4793 });
4794}
4795
4796TEST_F(LazyOpsTest, TestBroadcastTensors) {
4797 torch::Tensor a = torch::rand(
4798 {2, 1, 1}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4799 torch::Tensor b = torch::rand(
4800 {2, 1}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4801 std::vector<torch::Tensor> c = torch::broadcast_tensors({a, b});
4802 ForEachDevice([&](const torch::Device& device) {
4803 torch::Tensor lazy_a = CopyToDevice(a, device);
4804 torch::Tensor lazy_b = CopyToDevice(b, device);
4805 std::vector<torch::Tensor> lazy_c =
4806 torch::broadcast_tensors({lazy_a, lazy_b});
4807 ASSERT_EQ(c.size(), lazy_c.size());
4808 for (size_t i = 0; i < c.size(); ++i) {
4809 AllClose(c[i], lazy_c[i]);
4810 }
4811 });
4812}
4813
4814TEST_F(LazyOpsTest, TestOneIndex) {
4815 for (torch::ScalarType scalar_type :
4816 {torch::kFloat,
4817 torch::kByte,
4818 torch::kChar,
4819 torch::kShort,
4820 torch::kInt,
4821 torch::kLong}) {
4822 torch::Tensor params = isFloatingType(scalar_type)
4823 ? torch::rand(
4824 {4, 3, 5, 6, 7},
4825 torch::TensorOptions(scalar_type).device(DefaultDevice()))
4826 : torch::randint(
4827 100,
4828 {4, 3, 5, 6, 7},
4829 torch::TensorOptions(scalar_type).device(DefaultDevice()));
4830 torch::Tensor indices = torch::randint(
4831 -3,
4832 3,
4833 {2, 4, 3},
4834 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4835 torch::Tensor result = torch::index(params, {indices});
4836 ForEachDevice([&](const torch::Device& device) {
4837 torch::Tensor lazy_params = CopyToDevice(params, device);
4838 torch::Tensor lazy_indices = CopyToDevice(indices, device);
4839 torch::Tensor lazy_result = torch::index(lazy_params, {lazy_indices});
4840 AllEqual(result, lazy_result);
4841 });
4842 }
4843}
4844
4845TEST_F(LazyOpsTest, TestOneIndexTransfer) {
4846 for (torch::ScalarType scalar_type :
4847 {torch::kFloat,
4848 torch::kByte,
4849 torch::kChar,
4850 torch::kShort,
4851 torch::kInt,
4852 torch::kLong}) {
4853 torch::Tensor params = isFloatingType(scalar_type)
4854 ? torch::rand(
4855 {4, 3, 5, 6, 7},
4856 torch::TensorOptions(scalar_type).device(DefaultDevice()))
4857 : torch::randint(
4858 100,
4859 {4, 3, 5, 6, 7},
4860 torch::TensorOptions(scalar_type).device(DefaultDevice()));
4861 torch::Tensor indices = torch::randint(
4862 -3,
4863 3,
4864 {2, 4, 3},
4865 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4866 torch::Tensor result = torch::index(params, {indices});
4867 ForEachDevice([&](const torch::Device& device) {
4868 torch::Tensor lazy_params = CopyToDevice(params, device);
4869 torch::Tensor lazy_result = torch::index(lazy_params, {indices.cpu()});
4870 AllEqual(result, lazy_result);
4871 });
4872 }
4873}
4874
4875TEST_F(LazyOpsTest, TestNonzero) {
4876 torch::Tensor a = torch::zeros(
4877 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4878 a[0][1] = 1.0;
4879 a[1][0] = 2.0;
4880 a[3][1] = 3.0;
4881 torch::Tensor b = torch::nonzero(a);
4882 ForEachDevice([&](const torch::Device& device) {
4883 torch::Tensor lazy_a = CopyToDevice(a, device);
4884 torch::Tensor lazy_b = torch::nonzero(lazy_a);
4885 AllClose(b, lazy_b);
4886
4887 if (DebugUtil::ExperimentEnabled("nonzero")) {
4888 // If the nonzero support is enabled, we must not see any aten:: calls.
4889 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
4890 }
4891 ResetCounters();
4892 });
4893}
4894
4895TEST_F(LazyOpsTest, TestMaskedSelect) {
4896 torch::Tensor a = torch::rand(
4897 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4898 torch::Tensor b = torch::randint(
4899 0, 2, {5}, torch::TensorOptions(torch::kBool).device(DefaultDevice()));
4900 torch::Tensor c = torch::masked_select(a, b);
4901 ForEachDevice([&](const torch::Device& device) {
4902 torch::Tensor lazy_a = CopyToDevice(a, device);
4903 torch::Tensor lazy_b = CopyToDevice(b, device);
4904 torch::Tensor lazy_c = torch::masked_select(lazy_a, lazy_b);
4905 AllClose(c, lazy_c);
4906
4907 if (DebugUtil::ExperimentEnabled("masked_select")) {
4908 // If the masked_select support is enabled, we must not see any aten::
4909 // calls.
4910 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
4911 }
4912 ResetCounters();
4913 });
4914}
4915
4916TEST_F(LazyOpsTest, TestMaskedScatter) {
4917 torch::Tensor a = torch::rand(
4918 {3, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4919 torch::Tensor b = torch::randint(
4920 0, 2, {3, 5}, torch::TensorOptions(torch::kBool).device(DefaultDevice()));
4921 torch::Tensor c = torch::rand(
4922 {15}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
4923 torch::Tensor d = torch::masked_scatter(a, b, c);
4924 ForEachDevice([&](const torch::Device& device) {
4925 torch::Tensor lazy_a = CopyToDevice(a, device);
4926 torch::Tensor lazy_b = CopyToDevice(b, device);
4927 torch::Tensor lazy_c = CopyToDevice(c, device);
4928 torch::Tensor lazy_d = torch::masked_scatter(lazy_a, lazy_b, lazy_c);
4929 AllClose(d, lazy_d);
4930
4931 if (DebugUtil::ExperimentEnabled("masked_scatter")) {
4932 // If the masked_select support is enabled, we must not see any aten::
4933 // calls.
4934 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
4935 }
4936 ResetCounters();
4937 });
4938}
4939
4940TEST_F(LazyOpsTest, TestMultiIndexHeadNull) {
4941 for (torch::ScalarType scalar_type :
4942 {torch::kFloat,
4943 torch::kByte,
4944 torch::kChar,
4945 torch::kShort,
4946 torch::kInt,
4947 torch::kLong}) {
4948 torch::Tensor params = isFloatingType(scalar_type)
4949 ? torch::rand(
4950 {4, 3, 5, 6, 7},
4951 torch::TensorOptions(scalar_type).device(DefaultDevice()))
4952 : torch::randint(
4953 100,
4954 {4, 3, 5, 6, 7},
4955 torch::TensorOptions(scalar_type).device(DefaultDevice()));
4956 torch::Tensor indices_null;
4957 torch::Tensor indices_0 = torch::randint(
4958 -3,
4959 3,
4960 {2, 4, 3},
4961 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4962 torch::Tensor indices_1 = torch::randint(
4963 -3,
4964 3,
4965 {2, 4, 3},
4966 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
4967 torch::Tensor result =
4968 torch::index(params, {indices_null, indices_0, indices_1});
4969 ForEachDevice([&](const torch::Device& device) {
4970 torch::Tensor lazy_params = CopyToDevice(params, device);
4971 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
4972 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
4973 torch::Tensor lazy_result = torch::index(
4974 lazy_params, {indices_null, lazy_indices_0, lazy_indices_1});
4975 AllEqual(result, lazy_result);
4976 });
4977 }
4978}
4979
4980TEST_F(LazyOpsTest, TestMultiIndexMiddleNull) {
4981 for (torch::ScalarType scalar_type :
4982 {torch::kFloat,
4983 torch::kByte,
4984 torch::kChar,
4985 torch::kShort,
4986 torch::kInt,
4987 torch::kLong}) {
4988 torch::Tensor params = isFloatingType(scalar_type)
4989 ? torch::rand(
4990 {4, 3, 5, 6, 7},
4991 torch::TensorOptions(scalar_type).device(DefaultDevice()))
4992 : torch::randint(
4993 100,
4994 {4, 3, 5, 6, 7},
4995 torch::TensorOptions(scalar_type).device(DefaultDevice()));
4996 torch::Tensor indices_0 = torch::randint(
4997 -3,
4998 3,
4999 {2, 4, 3},
5000 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5001 torch::Tensor indices_null;
5002 torch::Tensor indices_1 = torch::randint(
5003 -3,
5004 3,
5005 {2, 4, 3},
5006 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5007 torch::Tensor result =
5008 torch::index(params, {indices_0, indices_null, indices_1});
5009 ForEachDevice([&](const torch::Device& device) {
5010 torch::Tensor lazy_params = CopyToDevice(params, device);
5011 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5012 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5013 torch::Tensor lazy_result = torch::index(
5014 lazy_params, {lazy_indices_0, indices_null, lazy_indices_1});
5015 AllEqual(result, lazy_result);
5016 });
5017 }
5018}
5019
5020TEST_F(LazyOpsTest, TestMultiIndexTailNull) {
5021 for (torch::ScalarType scalar_type :
5022 {torch::kFloat,
5023 torch::kByte,
5024 torch::kChar,
5025 torch::kShort,
5026 torch::kInt,
5027 torch::kLong}) {
5028 torch::Tensor params = isFloatingType(scalar_type)
5029 ? torch::rand(
5030 {4, 3, 5, 6, 7},
5031 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5032 : torch::randint(
5033 100,
5034 {4, 3, 5, 6, 7},
5035 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5036 torch::Tensor indices_0 = torch::randint(
5037 -3,
5038 3,
5039 {2, 4, 3},
5040 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5041 torch::Tensor indices_null;
5042 torch::Tensor indices_1 = torch::randint(
5043 -3,
5044 3,
5045 {2, 4, 3},
5046 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5047 torch::Tensor result =
5048 torch::index(params, {indices_0, indices_1, indices_null});
5049 ForEachDevice([&](const torch::Device& device) {
5050 torch::Tensor lazy_params = CopyToDevice(params, device);
5051 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5052 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5053 torch::Tensor lazy_result = torch::index(
5054 lazy_params, {lazy_indices_0, lazy_indices_1, indices_null});
5055 AllEqual(result, lazy_result);
5056 });
5057 }
5058}
5059
5060TEST_F(LazyOpsTest, TestMultiIndexMiddleBroadcast) {
5061 for (torch::ScalarType scalar_type :
5062 {torch::kFloat,
5063 torch::kByte,
5064 torch::kChar,
5065 torch::kShort,
5066 torch::kInt,
5067 torch::kLong}) {
5068 torch::Tensor params = isFloatingType(scalar_type)
5069 ? torch::rand(
5070 {4, 3, 5, 6, 7},
5071 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5072 : torch::randint(
5073 100,
5074 {4, 3, 5, 6, 7},
5075 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5076 torch::Tensor indices_0 = torch::randint(
5077 -3,
5078 3,
5079 {2, 4, 3},
5080 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5081 torch::Tensor indices_1 = torch::randint(
5082 -3,
5083 3,
5084 {2, 1, 3},
5085 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5086 torch::Tensor result = torch::index(params, {indices_0, indices_1});
5087 ForEachDevice([&](const torch::Device& device) {
5088 torch::Tensor lazy_params = CopyToDevice(params, device);
5089 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5090 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5091 torch::Tensor lazy_result =
5092 torch::index(lazy_params, {lazy_indices_0, lazy_indices_1});
5093 AllEqual(result, lazy_result);
5094 });
5095 }
5096}
5097
5098TEST_F(LazyOpsTest, TestMultiIndexTailBroadcast) {
5099 for (torch::ScalarType scalar_type :
5100 {torch::kFloat,
5101 torch::kByte,
5102 torch::kChar,
5103 torch::kShort,
5104 torch::kInt,
5105 torch::kLong}) {
5106 torch::Tensor params = isFloatingType(scalar_type)
5107 ? torch::rand(
5108 {4, 3, 5, 6, 7},
5109 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5110 : torch::randint(
5111 100,
5112 {4, 3, 5, 6, 7},
5113 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5114 torch::Tensor indices_0 = torch::randint(
5115 -3,
5116 3,
5117 {2, 1, 3},
5118 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5119 torch::Tensor indices_1 = torch::randint(
5120 -3,
5121 3,
5122 {2, 1},
5123 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5124 torch::Tensor result = torch::index(params, {indices_0, indices_1});
5125 ForEachDevice([&](const torch::Device& device) {
5126 torch::Tensor lazy_params = CopyToDevice(params, device);
5127 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5128 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5129 torch::Tensor lazy_result =
5130 torch::index(lazy_params, {lazy_indices_0, lazy_indices_1});
5131 AllEqual(result, lazy_result);
5132 });
5133 }
5134}
5135
5136TEST_F(LazyOpsTest, TestMaskIndex) {
5137 for (torch::ScalarType scalar_type :
5138 {torch::kFloat,
5139 torch::kByte,
5140 torch::kChar,
5141 torch::kShort,
5142 torch::kInt,
5143 torch::kLong}) {
5144 torch::Tensor params = isFloatingType(scalar_type)
5145 ? torch::rand(
5146 {2, 2}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
5147 : torch::randint(
5148 100,
5149 {2, 2},
5150 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5151 torch::Tensor indices = torch::randint(
5152 0,
5153 2,
5154 {2, 2},
5155 torch::TensorOptions(torch::kBool).device(DefaultDevice()));
5156 torch::Tensor result = torch::index(params, {indices});
5157 ForEachDevice([&](const torch::Device& device) {
5158 torch::Tensor lazy_params = CopyToDevice(params, device);
5159 torch::Tensor lazy_indices = CopyToDevice(indices, device);
5160 torch::Tensor lazy_result = torch::index(lazy_params, {lazy_indices});
5161 AllEqual(result, lazy_result);
5162 });
5163 }
5164}
5165
5166TEST_F(LazyOpsTest, TestOneIndexPut) {
5167 for (torch::ScalarType scalar_type :
5168 {torch::kFloat,
5169 torch::kByte,
5170 torch::kChar,
5171 torch::kShort,
5172 torch::kInt,
5173 torch::kLong}) {
5174 torch::Tensor params = isFloatingType(scalar_type)
5175 ? torch::rand(
5176 {4, 3, 5, 6, 7},
5177 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5178 : torch::randint(
5179 100,
5180 {4, 3, 5, 6, 7},
5181 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5182 torch::Tensor indices = torch::randint(
5183 -3,
5184 3,
5185 {2, 4, 3},
5186 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5187 torch::Tensor values = isFloatingType(scalar_type)
5188 ? torch::rand(
5189 {3, 5, 6, 7},
5190 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5191 : torch::randint(
5192 100,
5193 {3, 5, 6, 7},
5194 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5195 for (bool accumulate : {false, true}) {
5196 if (accumulate && IsCuda()) {
5197 GTEST_SKIP();
5198 }
5199 torch::Tensor result =
5200 torch::index_put(params, {indices}, values, accumulate);
5201 ForEachDevice([&](const torch::Device& device) {
5202 torch::Tensor lazy_params = CopyToDevice(params, device);
5203 torch::Tensor lazy_indices = CopyToDevice(indices, device);
5204 torch::Tensor lazy_values = CopyToDevice(values, device);
5205 torch::Tensor lazy_result = torch::index_put(
5206 lazy_params, {lazy_indices}, lazy_values, accumulate);
5207 AllEqual(result, lazy_result);
5208 });
5209 }
5210 }
5211}
5212
5213TEST_F(LazyOpsTest, TestOneIndexPutInPlace) {
5214 torch::Tensor indices = torch::randint(
5215 -3,
5216 3,
5217 {2, 4, 3},
5218 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5219 for (torch::ScalarType scalar_type :
5220 {torch::kFloat,
5221 torch::kByte,
5222 torch::kChar,
5223 torch::kShort,
5224 torch::kInt,
5225 torch::kLong}) {
5226 torch::Tensor values = torch::ones(
5227 {3, 5, 6, 7},
5228 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5229 for (bool accumulate : {false, true}) {
5230 if (accumulate && IsCuda()) {
5231 GTEST_SKIP();
5232 }
5233 ForEachDevice([&](const torch::Device& device) {
5234 torch::Tensor params = isFloatingType(scalar_type)
5235 ? torch::rand(
5236 {4, 3, 5, 6, 7},
5237 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5238 : torch::randint(
5239 100,
5240 {4, 3, 5, 6, 7},
5241 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5242 torch::Tensor lazy_params = CopyToDevice(params.clone(), device);
5243 torch::Tensor result =
5244 torch::index_put_(params, {indices}, values, accumulate);
5245 torch::Tensor lazy_indices = CopyToDevice(indices, device);
5246 torch::Tensor lazy_values = CopyToDevice(values, device);
5247 torch::Tensor lazy_result = torch::index_put_(
5248 lazy_params, {lazy_indices}, lazy_values, accumulate);
5249 AllEqual(result, lazy_result);
5250 AllEqual(params, lazy_params);
5251 });
5252 }
5253 }
5254}
5255
5256TEST_F(LazyOpsTest, TestOneIndexPutTransfer) {
5257 torch::Tensor indices = torch::randint(
5258 -3,
5259 3,
5260 {2, 4, 3},
5261 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5262 for (torch::ScalarType scalar_type :
5263 {torch::kFloat,
5264 torch::kByte,
5265 torch::kChar,
5266 torch::kShort,
5267 torch::kInt,
5268 torch::kLong}) {
5269 torch::Tensor params = isFloatingType(scalar_type)
5270 ? torch::rand(
5271 {4, 3, 5, 6, 7},
5272 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5273 : torch::randint(
5274 100,
5275 {4, 3, 5, 6, 7},
5276 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5277 torch::Tensor values = torch::ones(
5278 {3, 5, 6, 7},
5279 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5280 for (bool accumulate : {false, true}) {
5281 if (accumulate && IsCuda()) {
5282 GTEST_SKIP();
5283 }
5284 torch::Tensor result =
5285 torch::index_put(params, {indices}, values, accumulate);
5286 ForEachDevice([&](const torch::Device& device) {
5287 torch::Tensor lazy_params = CopyToDevice(params, device);
5288 torch::Tensor lazy_values = CopyToDevice(values, device);
5289 torch::Tensor lazy_result =
5290 torch::index_put(lazy_params, {indices}, lazy_values, accumulate);
5291 AllEqual(result, lazy_result);
5292 });
5293 }
5294 }
5295}
5296
5297TEST_F(LazyOpsTest, TestMultiIndexPut) {
5298 torch::Tensor indices_0 = torch::randint(
5299 -3,
5300 3,
5301 {2, 4, 3},
5302 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5303 torch::Tensor indices_1 = torch::randint(
5304 -3,
5305 3,
5306 {2, 4, 3},
5307 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5308 for (torch::ScalarType scalar_type :
5309 {torch::kFloat,
5310 torch::kByte,
5311 torch::kChar,
5312 torch::kShort,
5313 torch::kInt,
5314 torch::kLong}) {
5315 torch::Tensor params = isFloatingType(scalar_type)
5316 ? torch::rand(
5317 {4, 3, 5, 6, 7},
5318 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5319 : torch::randint(
5320 100,
5321 {4, 3, 5, 6, 7},
5322 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5323 torch::Tensor values = torch::ones(
5324 {5, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5325 for (bool accumulate : {false, true}) {
5326 if (accumulate && IsCuda()) {
5327 GTEST_SKIP();
5328 }
5329 torch::Tensor result =
5330 torch::index_put(params, {indices_0, indices_1}, values, accumulate);
5331 ForEachDevice([&](const torch::Device& device) {
5332 torch::Tensor lazy_params = CopyToDevice(params, device);
5333 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5334 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5335 torch::Tensor lazy_values = CopyToDevice(values, device);
5336 torch::Tensor lazy_result = torch::index_put(
5337 lazy_params,
5338 {lazy_indices_0, lazy_indices_1},
5339 lazy_values,
5340 accumulate);
5341 AllEqual(result, lazy_result);
5342 });
5343 }
5344 }
5345}
5346
5347TEST_F(LazyOpsTest, TestMultiIndexPutHeadNull) {
5348 torch::Tensor indices_0 = torch::randint(
5349 -3,
5350 3,
5351 {2, 4, 3},
5352 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5353 torch::Tensor indices_null;
5354 torch::Tensor indices_1 = torch::randint(
5355 -3,
5356 3,
5357 {2, 4, 3},
5358 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5359 for (torch::ScalarType scalar_type :
5360 {torch::kFloat,
5361 torch::kByte,
5362 torch::kChar,
5363 torch::kShort,
5364 torch::kInt,
5365 torch::kLong}) {
5366 torch::Tensor params = isFloatingType(scalar_type)
5367 ? torch::rand(
5368 {4, 3, 3, 6, 7},
5369 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5370 : torch::randint(
5371 100,
5372 {4, 3, 3, 6, 7},
5373 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5374 torch::Tensor values = torch::ones(
5375 {3, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5376 for (bool accumulate : {false, true}) {
5377 if (accumulate && IsCuda()) {
5378 GTEST_SKIP();
5379 }
5380 torch::Tensor result = torch::index_put(
5381 params, {indices_null, indices_0, indices_1}, values, accumulate);
5382 ForEachDevice([&](const torch::Device& device) {
5383 torch::Tensor lazy_params = CopyToDevice(params, device);
5384 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5385 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5386 torch::Tensor lazy_values = CopyToDevice(values, device);
5387 torch::Tensor lazy_result = torch::index_put(
5388 lazy_params,
5389 {indices_null, lazy_indices_0, lazy_indices_1},
5390 lazy_values,
5391 accumulate);
5392 AllEqual(result, lazy_result);
5393 });
5394 }
5395 }
5396}
5397
5398TEST_F(LazyOpsTest, TestMultiIndexPutMiddleNull) {
5399 torch::Tensor indices_0 = torch::randint(
5400 -3,
5401 3,
5402 {2, 4, 3},
5403 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5404 torch::Tensor indices_null;
5405 torch::Tensor indices_1 = torch::randint(
5406 -3,
5407 3,
5408 {2, 4, 3},
5409 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5410 for (torch::ScalarType scalar_type :
5411 {torch::kFloat,
5412 torch::kByte,
5413 torch::kChar,
5414 torch::kShort,
5415 torch::kInt,
5416 torch::kLong}) {
5417 torch::Tensor params = isFloatingType(scalar_type)
5418 ? torch::rand(
5419 {4, 3, 3, 6, 7},
5420 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5421 : torch::randint(
5422 100,
5423 {4, 3, 3, 6, 7},
5424 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5425 torch::Tensor values = torch::ones(
5426 {3, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5427 for (bool accumulate : {false, true}) {
5428 if (accumulate && IsCuda()) {
5429 GTEST_SKIP();
5430 }
5431 torch::Tensor result = torch::index_put(
5432 params, {indices_0, indices_null, indices_1}, values, accumulate);
5433 ForEachDevice([&](const torch::Device& device) {
5434 torch::Tensor lazy_params = CopyToDevice(params, device);
5435 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5436 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5437 torch::Tensor lazy_values = CopyToDevice(values, device);
5438 torch::Tensor lazy_result = torch::index_put(
5439 lazy_params,
5440 {lazy_indices_0, indices_null, lazy_indices_1},
5441 lazy_values,
5442 accumulate);
5443 AllEqual(result, lazy_result);
5444 });
5445 }
5446 }
5447}
5448
5449TEST_F(LazyOpsTest, TestMultiIndexPutTailNull) {
5450 torch::Tensor indices_0 = torch::randint(
5451 -3,
5452 3,
5453 {2, 4, 3},
5454 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5455 torch::Tensor indices_1 = torch::randint(
5456 -3,
5457 3,
5458 {2, 4, 3},
5459 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5460 torch::Tensor indices_null;
5461 for (torch::ScalarType scalar_type :
5462 {torch::kFloat,
5463 torch::kByte,
5464 torch::kChar,
5465 torch::kShort,
5466 torch::kInt,
5467 torch::kLong}) {
5468 torch::Tensor params = isFloatingType(scalar_type)
5469 ? torch::rand(
5470 {4, 3, 3, 6, 7},
5471 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5472 : torch::randint(
5473 100,
5474 {4, 3, 3, 6, 7},
5475 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5476 torch::Tensor values = torch::ones(
5477 {3, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5478 for (bool accumulate : {false, true}) {
5479 if (accumulate && IsCuda()) {
5480 GTEST_SKIP();
5481 }
5482 torch::Tensor result = torch::index_put(
5483 params, {indices_0, indices_1, indices_null}, values, accumulate);
5484 ForEachDevice([&](const torch::Device& device) {
5485 torch::Tensor lazy_params = CopyToDevice(params, device);
5486 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5487 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5488 torch::Tensor lazy_values = CopyToDevice(values, device);
5489 torch::Tensor lazy_result = torch::index_put(
5490 lazy_params,
5491 {lazy_indices_0, lazy_indices_1, indices_null},
5492 lazy_values,
5493 accumulate);
5494 AllEqual(result, lazy_result);
5495 });
5496 }
5497 }
5498}
5499
5500TEST_F(LazyOpsTest, TestMultiIndexPutMiddleBroadcast) {
5501 torch::Tensor indices_0 = torch::randint(
5502 -3,
5503 3,
5504 {2, 4, 3},
5505 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5506 torch::Tensor indices_1 = torch::randint(
5507 -3,
5508 3,
5509 {2, 1, 3},
5510 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5511 for (torch::ScalarType scalar_type :
5512 {torch::kFloat,
5513 torch::kByte,
5514 torch::kChar,
5515 torch::kShort,
5516 torch::kInt,
5517 torch::kLong}) {
5518 torch::Tensor params = isFloatingType(scalar_type)
5519 ? torch::rand(
5520 {4, 3, 5, 6, 7},
5521 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5522 : torch::randint(
5523 100,
5524 {4, 3, 5, 6, 7},
5525 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5526 torch::Tensor values = torch::ones(
5527 {5, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5528 for (bool accumulate : {false, true}) {
5529 if (accumulate && IsCuda()) {
5530 GTEST_SKIP();
5531 }
5532 torch::Tensor result =
5533 torch::index_put(params, {indices_0, indices_1}, values, accumulate);
5534 ForEachDevice([&](const torch::Device& device) {
5535 torch::Tensor lazy_params = CopyToDevice(params, device);
5536 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5537 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5538 torch::Tensor lazy_values = CopyToDevice(values, device);
5539 torch::Tensor lazy_result = torch::index_put(
5540 lazy_params,
5541 {lazy_indices_0, lazy_indices_1},
5542 lazy_values,
5543 accumulate);
5544 AllEqual(result, lazy_result);
5545 });
5546 }
5547 }
5548}
5549
5550TEST_F(LazyOpsTest, TestMultiIndexPutTailBroadcast) {
5551 torch::Tensor indices_0 = torch::randint(
5552 -3,
5553 3,
5554 {2, 1, 3},
5555 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5556 torch::Tensor indices_1 = torch::randint(
5557 -3,
5558 3,
5559 {2, 1},
5560 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5561 for (torch::ScalarType scalar_type :
5562 {torch::kFloat,
5563 torch::kByte,
5564 torch::kChar,
5565 torch::kShort,
5566 torch::kInt,
5567 torch::kLong}) {
5568 torch::Tensor params = isFloatingType(scalar_type)
5569 ? torch::rand(
5570 {4, 3, 5, 6, 7},
5571 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5572 : torch::randint(
5573 100,
5574 {4, 3, 5, 6, 7},
5575 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5576 torch::Tensor values = torch::ones(
5577 {5, 6, 7}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5578 for (bool accumulate : {false, true}) {
5579 if (accumulate && IsCuda()) {
5580 GTEST_SKIP();
5581 }
5582 torch::Tensor result =
5583 torch::index_put(params, {indices_0, indices_1}, values, accumulate);
5584 ForEachDevice([&](const torch::Device& device) {
5585 torch::Tensor lazy_params = CopyToDevice(params, device);
5586 torch::Tensor lazy_indices_0 = CopyToDevice(indices_0, device);
5587 torch::Tensor lazy_indices_1 = CopyToDevice(indices_1, device);
5588 torch::Tensor lazy_values = CopyToDevice(values, device);
5589 torch::Tensor lazy_result = torch::index_put(
5590 lazy_params,
5591 {lazy_indices_0, lazy_indices_1},
5592 lazy_values,
5593 accumulate);
5594 AllEqual(result, lazy_result);
5595 });
5596 }
5597 }
5598}
5599
5600TEST_F(LazyOpsTest, TestMaskIndexPut) {
5601 torch::Tensor indices =
5602 torch::tensor(
5603 {0, 1}, torch::TensorOptions(torch::kByte).device(DefaultDevice()))
5604 .to(torch::kBool);
5605 for (torch::ScalarType scalar_type :
5606 {torch::kFloat,
5607 torch::kByte,
5608 torch::kChar,
5609 torch::kShort,
5610 torch::kInt,
5611 torch::kLong}) {
5612 torch::Tensor params = isFloatingType(scalar_type)
5613 ? torch::rand(
5614 {2, 2}, torch::TensorOptions(scalar_type).device(DefaultDevice()))
5615 : torch::randint(
5616 100,
5617 {2, 2},
5618 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5619 torch::Tensor values = torch::ones(
5620 {2}, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5621 for (bool accumulate : {false, true}) {
5622 torch::Tensor result =
5623 torch::index_put(params, {indices}, values, accumulate);
5624 ForEachDevice([&](const torch::Device& device) {
5625 torch::Tensor lazy_params = CopyToDevice(params, device);
5626 torch::Tensor lazy_indices = CopyToDevice(indices, device);
5627 torch::Tensor lazy_values = CopyToDevice(values, device);
5628 torch::Tensor lazy_result = torch::index_put(
5629 lazy_params, {lazy_indices}, lazy_values, accumulate);
5630 AllEqual(result, lazy_result);
5631 });
5632 }
5633 }
5634}
5635
5636TEST_F(LazyOpsTest, TestIndexPutImpl) {
5637 torch::Tensor indices = torch::randint(
5638 -3,
5639 3,
5640 {2, 4, 3},
5641 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5642 for (torch::ScalarType scalar_type :
5643 {torch::kFloat,
5644 torch::kByte,
5645 torch::kChar,
5646 torch::kShort,
5647 torch::kInt,
5648 torch::kLong}) {
5649 torch::Tensor values = torch::ones(
5650 {3, 5, 6, 7},
5651 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5652 for (bool accumulate : {false, true}) {
5653 if (accumulate && IsCuda()) {
5654 GTEST_SKIP();
5655 }
5656 ForEachDevice([&](const torch::Device& device) {
5657 torch::Tensor params = isFloatingType(scalar_type)
5658 ? torch::rand(
5659 {4, 3, 5, 6, 7},
5660 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5661 : torch::randint(
5662 100,
5663 {4, 3, 5, 6, 7},
5664 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5665 torch::Tensor lazy_params = CopyToDevice(params.clone(), device);
5666 torch::Tensor result = torch::_index_put_impl_(
5667 params, {indices}, values, accumulate, /*unsafe=*/true);
5668 torch::Tensor lazy_indices = CopyToDevice(indices, device);
5669 torch::Tensor lazy_values = CopyToDevice(values, device);
5670 torch::Tensor lazy_result = torch::_index_put_impl_(
5671 lazy_params,
5672 {lazy_indices},
5673 lazy_values,
5674 accumulate,
5675 /*unsafe=*/true);
5676 AllEqual(result, lazy_result);
5677 AllEqual(params, lazy_params);
5678 });
5679 }
5680 }
5681}
5682
5683TEST_F(LazyOpsTest, TestIndexFillWithScalar) {
5684 torch::Tensor index = torch::tensor(
5685 {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5686 torch::Scalar value = 42;
5687 for (torch::ScalarType scalar_type :
5688 {torch::kFloat,
5689 torch::kByte,
5690 torch::kChar,
5691 torch::kShort,
5692 torch::kInt,
5693 torch::kLong}) {
5694 torch::Tensor base = isFloatingType(scalar_type)
5695 ? torch::rand(
5696 {3, 4, 5},
5697 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5698 : torch::randint(
5699 100,
5700 {3, 4, 5},
5701 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5702 int rank = base.dim();
5703 for (int dim = -rank; dim < rank; ++dim) {
5704 torch::Tensor result = torch::index_fill(base, dim, index, value);
5705 ForEachDevice([&](const torch::Device& device) {
5706 torch::Tensor lazy_base = CopyToDevice(base, device);
5707 torch::Tensor lazy_index = CopyToDevice(index, device);
5708 torch::Tensor lazy_result =
5709 torch::index_fill(lazy_base, dim, lazy_index, value);
5710 AllEqual(result, lazy_result);
5711 });
5712 }
5713 }
5714}
5715
5716TEST_F(LazyOpsTest, TestIndexFillWithScalarInPlace) {
5717 torch::Tensor index = torch::tensor(
5718 {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5719 torch::Scalar value = 42;
5720 int rank = 3;
5721 for (torch::ScalarType scalar_type :
5722 {torch::kFloat,
5723 torch::kByte,
5724 torch::kChar,
5725 torch::kShort,
5726 torch::kInt,
5727 torch::kLong}) {
5728 for (int dim = -rank; dim < rank; ++dim) {
5729 ForEachDevice([&](const torch::Device& device) {
5730 torch::Tensor base = isFloatingType(scalar_type)
5731 ? torch::rand(
5732 {3, 4, 5},
5733 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5734 : torch::randint(
5735 100,
5736 {3, 4, 5},
5737 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5738 torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
5739 torch::Tensor result = base.index_fill_(dim, index, value);
5740 torch::Tensor lazy_index = CopyToDevice(index, device);
5741 torch::Tensor lazy_result =
5742 lazy_base.index_fill_(dim, lazy_index, value);
5743 AllEqual(result, lazy_result);
5744 AllEqual(base, lazy_base);
5745 });
5746 }
5747 }
5748}
5749
5750TEST_F(LazyOpsTest, TestIndexFillWithTensor) {
5751 torch::Tensor index = torch::tensor(
5752 {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5753 for (torch::ScalarType scalar_type :
5754 {torch::kFloat,
5755 torch::kByte,
5756 torch::kChar,
5757 torch::kShort,
5758 torch::kInt,
5759 torch::kLong}) {
5760 torch::Tensor base = isFloatingType(scalar_type)
5761 ? torch::rand(
5762 {3, 4, 5},
5763 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5764 : torch::randint(
5765 100,
5766 {3, 4, 5},
5767 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5768 torch::Tensor value = torch::scalar_tensor(
5769 42, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5770 int rank = base.dim();
5771 for (int dim = -rank; dim < rank; ++dim) {
5772 torch::Tensor result = torch::index_fill(base, dim, index, value);
5773 ForEachDevice([&](const torch::Device& device) {
5774 torch::Tensor lazy_base = CopyToDevice(base, device);
5775 torch::Tensor lazy_index = CopyToDevice(index, device);
5776 torch::Tensor lazy_value = CopyToDevice(value, device);
5777 torch::Tensor lazy_result =
5778 torch::index_fill(lazy_base, dim, lazy_index, lazy_value);
5779 AllEqual(result, lazy_result);
5780 });
5781 }
5782 }
5783}
5784
5785TEST_F(LazyOpsTest, TestIndexFillWithTensorInPlace) {
5786 torch::Tensor index = torch::tensor(
5787 {0, 2}, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5788 for (torch::ScalarType scalar_type :
5789 {torch::kFloat,
5790 torch::kByte,
5791 torch::kChar,
5792 torch::kShort,
5793 torch::kInt,
5794 torch::kLong}) {
5795 torch::Tensor value = torch::scalar_tensor(
5796 42, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5797 int rank = 3;
5798 for (int dim = -rank; dim < rank; ++dim) {
5799 ForEachDevice([&](const torch::Device& device) {
5800 torch::Tensor base = isFloatingType(scalar_type)
5801 ? torch::rand(
5802 {3, 4, 5},
5803 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5804 : torch::randint(
5805 100,
5806 {3, 4, 5},
5807 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5808 torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
5809 torch::Tensor result = base.index_fill_(dim, index, value);
5810 torch::Tensor lazy_index = CopyToDevice(index, device);
5811 torch::Tensor lazy_value = CopyToDevice(value, device);
5812 torch::Tensor lazy_result =
5813 lazy_base.index_fill_(dim, lazy_index, lazy_value);
5814 AllEqual(result, lazy_result);
5815 AllEqual(base, lazy_base);
5816 });
5817 }
5818 }
5819}
5820
5821TEST_F(LazyOpsTest, TestIndexFillRank0) {
5822 torch::Tensor index = torch::scalar_tensor(
5823 2, torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5824 for (torch::ScalarType scalar_type :
5825 {torch::kFloat,
5826 torch::kByte,
5827 torch::kChar,
5828 torch::kShort,
5829 torch::kInt,
5830 torch::kLong}) {
5831 torch::Tensor base = isFloatingType(scalar_type)
5832 ? torch::rand(
5833 {3, 4, 5},
5834 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5835 : torch::randint(
5836 100,
5837 {3, 4, 5},
5838 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5839 torch::Tensor value = torch::scalar_tensor(
5840 42, torch::TensorOptions(scalar_type).device(DefaultDevice()));
5841 int rank = base.dim();
5842 for (int dim = -rank; dim < rank; ++dim) {
5843 torch::Tensor result = torch::index_fill(base, dim, index, value);
5844 ForEachDevice([&](const torch::Device& device) {
5845 torch::Tensor lazy_base = CopyToDevice(base, device);
5846 torch::Tensor lazy_index = CopyToDevice(index, device);
5847 torch::Tensor lazy_value = CopyToDevice(value, device);
5848 torch::Tensor lazy_result =
5849 torch::index_fill(lazy_base, dim, lazy_index, lazy_value);
5850 AllEqual(result, lazy_result);
5851 });
5852 }
5853 }
5854}
5855
5856TEST_F(LazyOpsTest, TestIndexAdd) {
5857 int index_size = 10;
5858 for (torch::ScalarType scalar_type :
5859 {torch::kFloat,
5860 torch::kByte,
5861 torch::kChar,
5862 torch::kShort,
5863 torch::kInt,
5864 torch::kLong}) {
5865 torch::Tensor base = isFloatingType(scalar_type)
5866 ? torch::rand(
5867 {5, 3, 7},
5868 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5869 : torch::randint(
5870 100,
5871 {5, 3, 7},
5872 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5873 int rank = base.dim();
5874 for (int dim = -rank; dim < rank; ++dim) {
5875 for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) {
5876 torch::Tensor index = torch::randint(
5877 0,
5878 base.size(dim),
5879 {index_size},
5880 torch::TensorOptions(index_scalar_type).device(DefaultDevice()));
5881 std::vector<int64_t> value_sizes(
5882 base.sizes().begin(), base.sizes().end());
5883 int canonical_dim = dim < 0 ? dim + rank : dim;
5884 value_sizes[canonical_dim] = index_size;
5885 torch::Tensor value = isFloatingType(scalar_type)
5886 ? torch::rand(
5887 value_sizes,
5888 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5889 : torch::randint(
5890 100,
5891 value_sizes,
5892 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5893 torch::Tensor result = torch::index_add(base, dim, index, value);
5894 ForEachDevice([&](const torch::Device& device) {
5895 torch::Tensor lazy_base = CopyToDevice(base, device);
5896 torch::Tensor lazy_index = CopyToDevice(index, device);
5897 torch::Tensor lazy_value = CopyToDevice(value, device);
5898 torch::Tensor lazy_result =
5899 torch::index_add(lazy_base, dim, lazy_index, lazy_value);
5900 AllClose(result, lazy_result);
5901 });
5902 }
5903 }
5904 }
5905}
5906
5907TEST_F(LazyOpsTest, TestIndexAddInPlace) {
5908 int index_size = 10;
5909 int rank = 3;
5910 for (torch::ScalarType scalar_type :
5911 {torch::kFloat,
5912 torch::kByte,
5913 torch::kChar,
5914 torch::kShort,
5915 torch::kInt,
5916 torch::kLong}) {
5917 for (int dim = -rank; dim < rank; ++dim) {
5918 ForEachDevice([&](const torch::Device& device) {
5919 torch::Tensor base = isFloatingType(scalar_type)
5920 ? torch::rand(
5921 {5, 3, 7},
5922 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5923 : torch::randint(
5924 100,
5925 {5, 3, 7},
5926 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5927 torch::Tensor index = torch::randint(
5928 0,
5929 base.size(dim),
5930 {index_size},
5931 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5932 std::vector<int64_t> value_sizes(
5933 base.sizes().begin(), base.sizes().end());
5934 int canonical_dim = dim < 0 ? dim + rank : dim;
5935 value_sizes[canonical_dim] = index_size;
5936 torch::Tensor value = isFloatingType(scalar_type)
5937 ? torch::rand(
5938 value_sizes,
5939 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5940 : torch::randint(
5941 100,
5942 value_sizes,
5943 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5944 torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
5945 torch::Tensor result = base.index_add_(dim, index, value);
5946 torch::Tensor lazy_index = CopyToDevice(index, device);
5947 torch::Tensor lazy_value = CopyToDevice(value, device);
5948 torch::Tensor lazy_result =
5949 lazy_base.index_add_(dim, lazy_index, lazy_value);
5950 AllClose(result, lazy_result);
5951 AllClose(base, lazy_base);
5952 });
5953 }
5954 }
5955}
5956
5957TEST_F(LazyOpsTest, TestIndexAddRank0) {
5958 for (torch::ScalarType scalar_type :
5959 {torch::kFloat,
5960 torch::kByte,
5961 torch::kChar,
5962 torch::kShort,
5963 torch::kInt,
5964 torch::kLong}) {
5965 torch::Tensor base = isFloatingType(scalar_type)
5966 ? torch::rand(
5967 {5, 3, 7},
5968 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5969 : torch::randint(
5970 100,
5971 {5, 3, 7},
5972 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5973 int rank = base.dim();
5974 for (int dim = -rank; dim < rank; ++dim) {
5975 torch::Tensor index = torch::randint(
5976 0,
5977 base.size(dim),
5978 at::IntArrayRef{},
5979 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
5980 std::vector<int64_t> value_sizes(
5981 base.sizes().begin(), base.sizes().end());
5982 int canonical_dim = dim < 0 ? dim + rank : dim;
5983 value_sizes[canonical_dim] = 1;
5984 torch::Tensor value = isFloatingType(scalar_type)
5985 ? torch::rand(
5986 value_sizes,
5987 torch::TensorOptions(scalar_type).device(DefaultDevice()))
5988 : torch::randint(
5989 100,
5990 value_sizes,
5991 torch::TensorOptions(scalar_type).device(DefaultDevice()));
5992 torch::Tensor result = torch::index_add(base, dim, index, value);
5993 ForEachDevice([&](const torch::Device& device) {
5994 torch::Tensor lazy_base = CopyToDevice(base, device);
5995 torch::Tensor lazy_index = CopyToDevice(index, device);
5996 torch::Tensor lazy_value = CopyToDevice(value, device);
5997 torch::Tensor lazy_result =
5998 torch::index_add(lazy_base, dim, lazy_index, lazy_value);
5999 AllEqual(result, lazy_result);
6000 });
6001 }
6002 }
6003}
6004
6005TEST_F(LazyOpsTest, TestIndexCopy) {
6006 for (torch::ScalarType scalar_type :
6007 {torch::kFloat,
6008 torch::kByte,
6009 torch::kChar,
6010 torch::kShort,
6011 torch::kInt,
6012 torch::kLong}) {
6013 torch::Tensor base = isFloatingType(scalar_type)
6014 ? torch::rand(
6015 {5, 3, 7},
6016 torch::TensorOptions(scalar_type).device(DefaultDevice()))
6017 : torch::randint(
6018 100,
6019 {5, 3, 7},
6020 torch::TensorOptions(scalar_type).device(DefaultDevice()));
6021 int rank = base.dim();
6022 for (int dim = -rank; dim < rank; ++dim) {
6023 torch::Tensor index = torch::randperm(
6024 base.size(dim),
6025 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
6026 torch::Tensor value = isFloatingType(scalar_type)
6027 ? torch::rand(
6028 base.sizes(),
6029 torch::TensorOptions(scalar_type).device(DefaultDevice()))
6030 : torch::randint(
6031 100,
6032 base.sizes(),
6033 torch::TensorOptions(scalar_type).device(DefaultDevice()));
6034 torch::Tensor result = torch::index_copy(base, dim, index, value);
6035 ForEachDevice([&](const torch::Device& device) {
6036 torch::Tensor lazy_base = CopyToDevice(base, device);
6037 torch::Tensor lazy_index = CopyToDevice(index, device);
6038 torch::Tensor lazy_value = CopyToDevice(value, device);
6039 torch::Tensor lazy_result =
6040 torch::index_copy(lazy_base, dim, lazy_index, lazy_value);
6041 AllEqual(result, lazy_result);
6042 });
6043 }
6044 }
6045}
6046
6047TEST_F(LazyOpsTest, TestIndexCopyInPlace) {
6048 if (IsCuda()) {
6049 GTEST_SKIP();
6050 }
6051 int index_size = 10;
6052 int rank = 3;
6053 for (torch::ScalarType scalar_type :
6054 {torch::kFloat,
6055 torch::kByte,
6056 torch::kChar,
6057 torch::kShort,
6058 torch::kInt,
6059 torch::kLong}) {
6060 for (int dim = -rank; dim < rank; ++dim) {
6061 ForEachDevice([&](const torch::Device& device) {
6062 torch::Tensor base = isFloatingType(scalar_type)
6063 ? torch::rand(
6064 {5, 3, 7},
6065 torch::TensorOptions(scalar_type).device(DefaultDevice()))
6066 : torch::randint(
6067 100,
6068 {5, 3, 7},
6069 torch::TensorOptions(scalar_type).device(DefaultDevice()));
6070 torch::Tensor index = torch::randint(
6071 0,
6072 base.size(dim),
6073 {index_size},
6074 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
6075 std::vector<int64_t> value_sizes(
6076 base.sizes().begin(), base.sizes().end());
6077 int canonical_dim = dim < 0 ? dim + rank : dim;
6078 value_sizes[canonical_dim] = index_size;
6079 torch::Tensor value = isFloatingType(scalar_type)
6080 ? torch::rand(
6081 value_sizes,
6082 torch::TensorOptions(scalar_type).device(DefaultDevice()))
6083 : torch::randint(
6084 100,
6085 value_sizes,
6086 torch::TensorOptions(scalar_type).device(DefaultDevice()));
6087 torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
6088 torch::Tensor result = base.index_copy_(dim, index, value);
6089 torch::Tensor lazy_index = CopyToDevice(index, device);
6090 torch::Tensor lazy_value = CopyToDevice(value, device);
6091 torch::Tensor lazy_result =
6092 lazy_base.index_copy_(dim, lazy_index, lazy_value);
6093 AllEqual(result, lazy_result);
6094 AllEqual(base, lazy_base);
6095 });
6096 }
6097 }
6098}
6099
6100TEST_F(LazyOpsTest, TestIndexCopyRank0) {
6101 for (torch::ScalarType scalar_type :
6102 {torch::kFloat,
6103 torch::kByte,
6104 torch::kChar,
6105 torch::kShort,
6106 torch::kInt,
6107 torch::kLong}) {
6108 torch::Tensor base = isFloatingType(scalar_type)
6109 ? torch::rand(
6110 {5, 3, 7},
6111 torch::TensorOptions(scalar_type).device(DefaultDevice()))
6112 : torch::randint(
6113 100,
6114 {5, 3, 7},
6115 torch::TensorOptions(scalar_type).device(DefaultDevice()));
6116 int rank = base.dim();
6117 for (int dim = -rank; dim < rank; ++dim) {
6118 torch::Tensor index = torch::randint(
6119 0,
6120 base.size(dim),
6121 at::IntArrayRef{},
6122 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
6123 std::vector<int64_t> value_sizes(
6124 base.sizes().begin(), base.sizes().end());
6125 int canonical_dim = dim < 0 ? dim + rank : dim;
6126 value_sizes[canonical_dim] = 1;
6127 torch::Tensor value = isFloatingType(scalar_type)
6128 ? torch::rand(
6129 value_sizes,
6130 torch::TensorOptions(scalar_type).device(DefaultDevice()))
6131 : torch::randint(
6132 100,
6133 value_sizes,
6134 torch::TensorOptions(scalar_type).device(DefaultDevice()));
6135 torch::Tensor result = torch::index_copy(base, dim, index, value);
6136 ForEachDevice([&](const torch::Device& device) {
6137 torch::Tensor lazy_base = CopyToDevice(base, device);
6138 torch::Tensor lazy_index = CopyToDevice(index, device);
6139 torch::Tensor lazy_value = CopyToDevice(value, device);
6140 torch::Tensor lazy_result =
6141 torch::index_copy(lazy_base, dim, lazy_index, lazy_value);
6142 AllEqual(result, lazy_result);
6143 });
6144 }
6145 }
6146}
6147
6148TEST_F(LazyOpsTest, TestRelu) {
6149 torch::Tensor input = torch::rand(
6150 {2, 1, 4, 6},
6151 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6152 torch::Tensor output = torch::relu(input);
6153 ForEachDevice([&](const torch::Device& device) {
6154 torch::Tensor lazy_input = CopyToDevice(input, device);
6155 torch::Tensor lazy_output = torch::relu(lazy_input);
6156 AllClose(output, lazy_output);
6157 });
6158}
6159
6160TEST_F(LazyOpsTest, TestReluInPlace) {
6161 torch::Tensor input = torch::rand(
6162 {2, 1, 4, 6},
6163 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6164 ForEachDevice([&](const torch::Device& device) {
6165 torch::Tensor lazy_input = CopyToDevice(input, device);
6166 torch::Tensor output = torch::relu_(input);
6167 torch::Tensor lazy_output = torch::relu_(lazy_input);
6168 AllClose(output, lazy_output);
6169 AllClose(input, lazy_input);
6170 });
6171}
6172
6173TEST_F(LazyOpsTest, TestHardshrink) {
6174 torch::Tensor input = torch::randn(
6175 {10}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6176 torch::Tensor output = torch::hardshrink(input);
6177 ForEachDevice([&](const torch::Device& device) {
6178 torch::Tensor lazy_input = CopyToDevice(input, device);
6179 torch::Tensor lazy_output = torch::hardshrink(lazy_input);
6180 AllClose(output, lazy_output);
6181 });
6182}
6183
6184TEST_F(LazyOpsTest, TestHardSigmoid) {
6185 torch::Tensor input = torch::randn(
6186 {10}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6187 torch::Tensor output = torch::hardsigmoid(input);
6188 ForEachDevice([&](const torch::Device& device) {
6189 torch::Tensor lazy_input = CopyToDevice(input, device);
6190 torch::Tensor lazy_output = torch::hardsigmoid(lazy_input);
6191 AllClose(output, lazy_output);
6192 });
6193}
6194
6195TEST_F(LazyOpsTest, TestHardSigmoidInPlace) {
6196 ForEachDevice([&](const torch::Device& device) {
6197 torch::Tensor input = torch::randn(
6198 {10}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6199 torch::Tensor lazy_input = CopyToDevice(input, device);
6200 torch::Tensor output = torch::hardsigmoid_(input);
6201 torch::Tensor lazy_output = torch::hardsigmoid_(lazy_input);
6202 AllClose(input, lazy_input);
6203 AllClose(output, lazy_output);
6204 });
6205}
6206
6207TEST_F(LazyOpsTest, TestHardSigmoidBackward) {
6208 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
6209 return torch::hardsigmoid(inputs[0]);
6210 };
6211 ForEachDevice([&](const torch::Device& device) {
6212 TestBackward(
6213 {torch::randn(
6214 {10},
6215 torch::TensorOptions(torch::kFloat)
6216 .device(DefaultDevice())
6217 .requires_grad(true))},
6218 device,
6219 testfn);
6220 });
6221}
6222
6223TEST_F(LazyOpsTest, TestSoftshrink) {
6224 torch::Tensor input = torch::randn(
6225 {10}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6226 torch::Tensor output = torch::softshrink(input);
6227 ForEachDevice([&](const torch::Device& device) {
6228 torch::Tensor lazy_input = CopyToDevice(input, device);
6229 torch::Tensor lazy_output = torch::softshrink(lazy_input);
6230 AllClose(output, lazy_output);
6231 });
6232}
6233
6234TEST_F(LazyOpsTest, TestHardtanh) {
6235 torch::Tensor input = torch::randn(
6236 {10}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6237 torch::Tensor output = torch::hardtanh(input);
6238 ForEachDevice([&](const torch::Device& device) {
6239 torch::Tensor lazy_input = CopyToDevice(input, device);
6240 torch::Tensor lazy_output = torch::hardtanh(lazy_input);
6241 AllClose(output, lazy_output);
6242 });
6243}
6244
6245TEST_F(LazyOpsTest, TestHardtanhInPlace) {
6246 torch::Tensor input = torch::randn(
6247 {10}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6248 ForEachDevice([&](const torch::Device& device) {
6249 torch::Tensor lazy_input = CopyToDevice(input, device);
6250 torch::Tensor output = torch::hardtanh_(input);
6251 torch::Tensor lazy_output = torch::hardtanh_(lazy_input);
6252 AllClose(output, lazy_output);
6253 AllClose(input, lazy_input);
6254 });
6255}
6256
6257TEST_F(LazyOpsTest, TestLeakyRelu) {
6258 torch::Tensor input = torch::rand(
6259 {2, 1, 4, 6},
6260 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6261 double negative_slope = 0.01;
6262 torch::Tensor output = torch::leaky_relu(input, negative_slope);
6263 ForEachDevice([&](const torch::Device& device) {
6264 torch::Tensor lazy_input = CopyToDevice(input, device);
6265 torch::Tensor lazy_output = torch::leaky_relu(lazy_input, negative_slope);
6266 AllClose(output, lazy_output);
6267 });
6268}
6269
6270TEST_F(LazyOpsTest, TestLeakyReluInPlace) {
6271 torch::Tensor input = torch::rand(
6272 {2, 1, 4, 6},
6273 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6274 double negative_slope = 0.01;
6275 ForEachDevice([&](const torch::Device& device) {
6276 torch::Tensor lazy_input = CopyToDevice(input, device);
6277 torch::Tensor output = torch::leaky_relu_(input, negative_slope);
6278 torch::Tensor lazy_output = torch::leaky_relu_(lazy_input, negative_slope);
6279 AllClose(output, lazy_output);
6280 AllClose(input, lazy_input);
6281 });
6282}
6283
6284TEST_F(LazyOpsTest, TestExp) {
6285 torch::Tensor a = torch::rand(
6286 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6287 torch::Tensor b = torch::exp(a);
6288 ForEachDevice([&](const torch::Device& device) {
6289 torch::Tensor lazy_a = CopyToDevice(a, device);
6290 torch::Tensor lazy_b = torch::exp(lazy_a);
6291 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6292 });
6293}
6294
6295TEST_F(LazyOpsTest, TestExpm1) {
6296 torch::Tensor a = torch::rand(
6297 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6298 torch::Tensor b = torch::expm1(a);
6299 ForEachDevice([&](const torch::Device& device) {
6300 torch::Tensor lazy_a = CopyToDevice(a, device);
6301 torch::Tensor lazy_b = torch::expm1(lazy_a);
6302 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6303 });
6304}
6305
6306TEST_F(LazyOpsTest, TestLog) {
6307 torch::Tensor a = torch::rand(
6308 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6309 torch::Tensor b = torch::log(a);
6310 ForEachDevice([&](const torch::Device& device) {
6311 torch::Tensor lazy_a = CopyToDevice(a, device);
6312 torch::Tensor lazy_b = torch::log(lazy_a);
6313 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6314 });
6315}
6316
6317TEST_F(LazyOpsTest, TestLog2) {
6318 torch::Tensor a = torch::rand(
6319 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6320 torch::Tensor b = torch::log2(a);
6321 ForEachDevice([&](const torch::Device& device) {
6322 torch::Tensor lazy_a = CopyToDevice(a, device);
6323 torch::Tensor lazy_b = torch::log2(lazy_a);
6324 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6325 });
6326}
6327
6328TEST_F(LazyOpsTest, TestLog10) {
6329 torch::Tensor a = torch::rand(
6330 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6331 torch::Tensor b = torch::log10(a);
6332 ForEachDevice([&](const torch::Device& device) {
6333 torch::Tensor lazy_a = CopyToDevice(a, device);
6334 torch::Tensor lazy_b = torch::log10(lazy_a);
6335 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6336 });
6337}
6338
6339TEST_F(LazyOpsTest, TestLog1p) {
6340 torch::Tensor a = torch::rand(
6341 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6342 torch::Tensor b = torch::log1p(a);
6343 ForEachDevice([&](const torch::Device& device) {
6344 torch::Tensor lazy_a = CopyToDevice(a, device);
6345 torch::Tensor lazy_b = torch::log1p(lazy_a);
6346 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6347 });
6348}
6349
6350TEST_F(LazyOpsTest, TestErf) {
6351 torch::Tensor a = torch::randn(
6352 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6353 torch::Tensor b = torch::erf(a);
6354 ForEachDevice([&](const torch::Device& device) {
6355 torch::Tensor lazy_a = CopyToDevice(a, device);
6356 torch::Tensor lazy_b = torch::erf(lazy_a);
6357 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6358 });
6359}
6360
6361TEST_F(LazyOpsTest, TestErfc) {
6362 torch::Tensor a = torch::randn(
6363 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6364 torch::Tensor b = torch::erfc(a);
6365 ForEachDevice([&](const torch::Device& device) {
6366 torch::Tensor lazy_a = CopyToDevice(a, device);
6367 torch::Tensor lazy_b = torch::erfc(lazy_a);
6368 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6369 });
6370}
6371
6372TEST_F(LazyOpsTest, TestErfinv) {
6373 torch::Tensor a = torch::rand(
6374 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6375 torch::Tensor b = torch::erfinv(a);
6376 ForEachDevice([&](const torch::Device& device) {
6377 torch::Tensor lazy_a = CopyToDevice(a, device);
6378 torch::Tensor lazy_b = torch::erfinv(lazy_a);
6379 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6380 });
6381}
6382
6383TEST_F(LazyOpsTest, TestSqrt) {
6384 torch::Tensor a = torch::abs(torch::rand(
6385 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
6386 torch::Tensor b = torch::sqrt(a);
6387 ForEachDevice([&](const torch::Device& device) {
6388 torch::Tensor lazy_a = CopyToDevice(a, device);
6389 torch::Tensor lazy_b = torch::sqrt(lazy_a);
6390 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6391 });
6392}
6393
6394TEST_F(LazyOpsTest, TestRsqrt) {
6395 torch::Tensor a = torch::abs(torch::rand(
6396 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
6397 torch::Tensor b = torch::rsqrt(a);
6398 ForEachDevice([&](const torch::Device& device) {
6399 torch::Tensor lazy_a = CopyToDevice(a, device);
6400 torch::Tensor lazy_b = torch::rsqrt(lazy_a);
6401 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6402 });
6403}
6404
6405TEST_F(LazyOpsTest, TestReciprocal) {
6406 torch::Tensor a = torch::randn(
6407 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6408 torch::Tensor b = torch::reciprocal(a);
6409 ForEachDevice([&](const torch::Device& device) {
6410 torch::Tensor lazy_a = CopyToDevice(a, device);
6411 torch::Tensor lazy_b = torch::reciprocal(lazy_a);
6412 AllClose(b, lazy_b, /*rtol=*/1e-3, /*atol=*/1e-5);
6413 });
6414}
6415
6416TEST_F(LazyOpsTest, TestPowTensorScalar) {
6417 torch::Tensor base = torch::rand(
6418 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6419 torch::Scalar exponent = 4.09;
6420 torch::Tensor result = torch::pow(base, exponent);
6421 ForEachDevice([&](const torch::Device& device) {
6422 torch::Tensor lazy_base = CopyToDevice(base, device);
6423 torch::Tensor lazy_result = torch::pow(lazy_base, exponent);
6424 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
6425 });
6426}
6427
6428TEST_F(LazyOpsTest, TestPowTensorScalarInPlace) {
6429 torch::Tensor base = torch::rand(
6430 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6431 torch::Scalar exponent = 4.09;
6432 ForEachDevice([&](const torch::Device& device) {
6433 torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
6434 torch::Tensor result = base.pow_(exponent);
6435 torch::Tensor lazy_result = lazy_base.pow_(exponent);
6436 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
6437 AllClose(base, lazy_base, /*rtol=*/1e-3, /*atol=*/1e-5);
6438 });
6439}
6440
6441TEST_F(LazyOpsTest, TestPowTensorTensor) {
6442 torch::Tensor base = torch::abs(torch::rand(
6443 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
6444 torch::Tensor exponent = torch::rand(
6445 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6446 torch::Tensor result = torch::pow(base, exponent);
6447 ForEachDevice([&](const torch::Device& device) {
6448 torch::Tensor lazy_base = CopyToDevice(base, device);
6449 torch::Tensor lazy_exponent = CopyToDevice(exponent, device);
6450 torch::Tensor lazy_result = torch::pow(lazy_base, lazy_exponent);
6451 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
6452 });
6453}
6454
6455TEST_F(LazyOpsTest, TestPowTensorTensorInPlace) {
6456 torch::Tensor base = torch::abs(torch::rand(
6457 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
6458 torch::Tensor exponent = torch::rand(
6459 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6460 ForEachDevice([&](const torch::Device& device) {
6461 torch::Tensor lazy_base = CopyToDevice(base.clone(), device);
6462 torch::Tensor result = base.pow_(exponent);
6463 torch::Tensor lazy_exponent = CopyToDevice(exponent, device);
6464 torch::Tensor lazy_result = lazy_base.pow_(lazy_exponent);
6465 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
6466 AllClose(base, lazy_base, /*rtol=*/1e-3, /*atol=*/1e-5);
6467 });
6468}
6469
6470TEST_F(LazyOpsTest, TestPowTensorTensorBroadcast) {
6471 torch::Tensor base = torch::abs(torch::rand(
6472 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
6473 torch::Tensor exponent = torch::rand(
6474 {4, 1}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6475 torch::Tensor result = torch::pow(base, exponent);
6476 ForEachDevice([&](const torch::Device& device) {
6477 torch::Tensor lazy_base = CopyToDevice(base, device);
6478 torch::Tensor lazy_exponent = CopyToDevice(exponent, device);
6479 torch::Tensor lazy_result = torch::pow(lazy_base, lazy_exponent);
6480 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
6481 });
6482}
6483
6484TEST_F(LazyOpsTest, TestPowScalarTensor) {
6485 torch::Scalar base = 3.5;
6486 torch::Tensor exponent = torch::rand({4, 2});
6487 torch::Tensor result = torch::pow(base, exponent);
6488 ForEachDevice([&](const torch::Device& device) {
6489 torch::Tensor lazy_exponent = CopyToDevice(exponent, device);
6490 torch::Tensor lazy_result = torch::pow(base, lazy_exponent);
6491 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
6492 });
6493}
6494
6495TEST_F(LazyOpsTest, TestPowIntExponent) {
6496 torch::Tensor base = torch::abs(torch::rand(
6497 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())));
6498 torch::Scalar exponent = 3;
6499 torch::Tensor result = torch::pow(base, exponent);
6500 ForEachDevice([&](const torch::Device& device) {
6501 torch::Tensor lazy_base = CopyToDevice(base, device);
6502 torch::Tensor lazy_result = torch::pow(lazy_base, exponent);
6503 AllClose(result, lazy_result, /*rtol=*/1e-3, /*atol=*/1e-5);
6504 });
6505}
6506
6507TEST_F(LazyOpsTest, TestFmodScalar) {
6508 torch::Tensor a =
6509 torch::rand(
6510 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6511 100.0;
6512 torch::Scalar divisor = 2.0;
6513 torch::Tensor b = torch::fmod(a, divisor);
6514 ForEachDevice([&](const torch::Device& device) {
6515 torch::Tensor lazy_a = CopyToDevice(a, device);
6516 torch::Tensor lazy_b = torch::fmod(lazy_a, divisor);
6517 AllClose(b, lazy_b);
6518 });
6519}
6520
6521TEST_F(LazyOpsTest, TestFmodScalarInPlace) {
6522 torch::Scalar divisor = 2.0;
6523 ForEachDevice([&](const torch::Device& device) {
6524 torch::Tensor a =
6525 torch::rand(
6526 {2, 2},
6527 torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6528 100.0;
6529 torch::Tensor lazy_a = CopyToDevice(a, device);
6530 torch::Tensor b = a.fmod_(divisor);
6531 torch::Tensor lazy_b = lazy_a.fmod_(divisor);
6532 AllClose(b, lazy_b);
6533 AllClose(a, lazy_a);
6534 });
6535}
6536
6537TEST_F(LazyOpsTest, TestFmodTensor) {
6538 torch::Tensor a =
6539 torch::rand(
6540 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6541 100.0;
6542 torch::Tensor b =
6543 torch::rand(
6544 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6545 10.0;
6546 torch::Tensor c = torch::fmod(a, b);
6547 ForEachDevice([&](const torch::Device& device) {
6548 torch::Tensor lazy_a = CopyToDevice(a, device);
6549 torch::Tensor lazy_b = CopyToDevice(b, device);
6550 torch::Tensor lazy_c = torch::fmod(lazy_a, lazy_b);
6551 AllClose(c, lazy_c);
6552 });
6553}
6554
6555TEST_F(LazyOpsTest, TestFmodTensorInPlace) {
6556 torch::Tensor b =
6557 torch::rand(
6558 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6559 10.0;
6560 ForEachDevice([&](const torch::Device& device) {
6561 torch::Tensor a =
6562 torch::rand(
6563 {2, 2},
6564 torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6565 100.0;
6566 torch::Tensor lazy_a = CopyToDevice(a, device);
6567 torch::Tensor c = a.fmod_(b);
6568 torch::Tensor lazy_b = CopyToDevice(b, device);
6569 torch::Tensor lazy_c = lazy_a.fmod_(lazy_b);
6570 AllClose(c, lazy_c);
6571 AllClose(a, lazy_a);
6572 });
6573}
6574
6575TEST_F(LazyOpsTest, TestRemainderScalar) {
6576 torch::Tensor a =
6577 torch::randn(
6578 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6579 100.0;
6580 torch::Scalar divisor = -2.0;
6581 torch::Tensor b = torch::remainder(a, divisor);
6582 ForEachDevice([&](const torch::Device& device) {
6583 torch::Tensor lazy_a = CopyToDevice(a, device);
6584 torch::Tensor lazy_b = torch::remainder(lazy_a, divisor);
6585 AllClose(b, lazy_b);
6586 });
6587}
6588
6589TEST_F(LazyOpsTest, TestRemainderScalarInPlace) {
6590 torch::Scalar divisor = -2.0;
6591 ForEachDevice([&](const torch::Device& device) {
6592 torch::Tensor a =
6593 torch::randn(
6594 {2, 2},
6595 torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6596 100.0;
6597 torch::Tensor lazy_a = CopyToDevice(a, device);
6598 torch::Tensor b = a.remainder_(divisor);
6599 torch::Tensor lazy_b = lazy_a.remainder_(divisor);
6600 AllClose(b, lazy_b);
6601 AllClose(a, lazy_a);
6602 });
6603}
6604
6605TEST_F(LazyOpsTest, TestRemainderTensor) {
6606 torch::Tensor a =
6607 torch::randn(
6608 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6609 100.0;
6610 torch::Tensor b =
6611 torch::randn(
6612 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6613 10.0;
6614 torch::Tensor c = torch::remainder(a, b);
6615 ForEachDevice([&](const torch::Device& device) {
6616 torch::Tensor lazy_a = CopyToDevice(a, device);
6617 torch::Tensor lazy_b = CopyToDevice(b, device);
6618 torch::Tensor lazy_c = torch::remainder(lazy_a, lazy_b);
6619 AllClose(c, lazy_c, /*rtol=*/1e-4, /*atol=*/1e-6);
6620 });
6621}
6622
6623TEST_F(LazyOpsTest, TestRemainderTensorInPlace) {
6624 torch::Tensor b =
6625 torch::randn(
6626 {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6627 10.0;
6628 ForEachDevice([&](const torch::Device& device) {
6629 torch::Tensor a =
6630 torch::randn(
6631 {2, 2},
6632 torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
6633 100.0;
6634 torch::Tensor lazy_a = CopyToDevice(a, device);
6635 torch::Tensor c = a.remainder_(b);
6636 torch::Tensor lazy_b = CopyToDevice(b, device);
6637 torch::Tensor lazy_c = lazy_a.remainder_(lazy_b);
6638 AllClose(c, lazy_c, /*rtol=*/1e-4, /*atol=*/1e-6);
6639 AllClose(a, lazy_a, /*rtol=*/1e-4, /*atol=*/1e-6);
6640 });
6641}
6642
6643TEST_F(LazyOpsTest, TestWhere) {
6644 torch::Tensor a = torch::rand(
6645 {3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6646 torch::Tensor b = torch::rand(
6647 {3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6648 torch::Tensor c = torch::empty(
6649 {3, 3}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
6650 for (int i = 0; i < 3; ++i) {
6651 for (int j = 0; j < 3; ++j) {
6652 c[i][j] = i == j;
6653 }
6654 }
6655 torch::Tensor d = torch::where(c, a, b);
6656 ForEachDevice([&](const torch::Device& device) {
6657 torch::Tensor lazy_a = CopyToDevice(a, device);
6658 torch::Tensor lazy_b = CopyToDevice(b, device);
6659 torch::Tensor lazy_c = CopyToDevice(c, device);
6660 torch::Tensor lazy_d = torch::where(lazy_c, lazy_a, lazy_b);
6661 AllClose(d, lazy_d);
6662 });
6663}
6664
6665TEST_F(LazyOpsTest, TestWhereBroadcast) {
6666 torch::Tensor a = torch::rand(
6667 {3, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6668 torch::Tensor b = torch::zeros(
6669 {}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6670 torch::Tensor c = torch::empty(
6671 {3, 3}, torch::TensorOptions(torch::kByte).device(DefaultDevice()));
6672 for (int i = 0; i < 3; ++i) {
6673 for (int j = 0; j < 3; ++j) {
6674 c[i][j] = i == j;
6675 }
6676 }
6677 torch::Tensor d = torch::where(c, a, b);
6678 ForEachDevice([&](const torch::Device& device) {
6679 torch::Tensor lazy_a = CopyToDevice(a, device);
6680 torch::Tensor lazy_b = CopyToDevice(b, device);
6681 torch::Tensor lazy_c = CopyToDevice(c, device);
6682 torch::Tensor lazy_d = torch::where(lazy_c, lazy_a, lazy_b);
6683 AllClose(d, lazy_d);
6684 });
6685}
6686
6687TEST_F(LazyOpsTest, TestThreshold) {
6688 torch::Tensor input = torch::rand(
6689 {2, 1, 4, 6},
6690 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6691 float threshold = 0.4;
6692 float value = 20;
6693 torch::Tensor output = torch::threshold(input, threshold, value);
6694 ForEachDevice([&](const torch::Device& device) {
6695 torch::Tensor lazy_input = CopyToDevice(input, device);
6696 torch::Tensor lazy_output = torch::threshold(lazy_input, threshold, value);
6697 AllClose(output, lazy_output);
6698 });
6699}
6700
6701TEST_F(LazyOpsTest, TestThresholdBackward) {
6702 float threshold = 0.4;
6703 float value = 20;
6704
6705 auto testFunction =
6706 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
6707 return torch::threshold(inputs[0], threshold, value);
6708 };
6709
6710 ForEachDevice([&](const torch::Device& device) {
6711 TestBackward(
6712 {torch::rand(
6713 {2, 1, 4, 6},
6714 torch::TensorOptions(torch::kFloat)
6715 .device(DefaultDevice())
6716 .requires_grad(true))},
6717 device,
6718 testFunction);
6719 });
6720}
6721
6722TEST_F(LazyOpsTest, TestThresholdInPlace) {
6723 torch::Tensor input = torch::rand(
6724 {2, 1, 4, 6},
6725 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6726 torch::Tensor output = input.clone();
6727 float threshold = 0.4;
6728 float value = 20;
6729 torch::threshold_(output, threshold, value);
6730 ForEachDevice([&](const torch::Device& device) {
6731 torch::Tensor lazy_output = CopyToDevice(input, device);
6732 torch::threshold_(lazy_output, threshold, value);
6733 AllClose(output, lazy_output);
6734 });
6735}
6736
6737TEST_F(LazyOpsTest, TestElu) {
6738 torch::Tensor input = torch::rand(
6739 {2, 1, 4, 6},
6740 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6741 torch::Scalar alpha = 0.5;
6742 torch::Scalar scale = 2.5;
6743 torch::Scalar input_scale = 1.5;
6744 torch::Tensor output = torch::elu(input, alpha, scale, input_scale);
6745 ForEachDevice([&](const torch::Device& device) {
6746 torch::Tensor lazy_input = CopyToDevice(input, device);
6747 torch::Tensor lazy_output =
6748 torch::elu(lazy_input, alpha, scale, input_scale);
6749 AllClose(output, lazy_output);
6750 });
6751}
6752
6753TEST_F(LazyOpsTest, TestEluInPlace) {
6754 torch::Tensor input = torch::rand(
6755 {2, 1, 4, 6},
6756 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6757 torch::Scalar alpha = 0.5;
6758 torch::Scalar scale = 2.5;
6759 torch::Scalar input_scale = 1.5;
6760 ForEachDevice([&](const torch::Device& device) {
6761 torch::Tensor lazy_input = CopyToDevice(input, device);
6762 torch::Tensor output = torch::elu_(input, alpha, scale, input_scale);
6763 torch::Tensor lazy_output =
6764 torch::elu_(lazy_input, alpha, scale, input_scale);
6765 AllClose(output, lazy_output);
6766 AllClose(input, lazy_input);
6767 });
6768}
6769
6770TEST_F(LazyOpsTest, TestSelu) {
6771 torch::Tensor input = torch::rand(
6772 {2, 1, 4, 6},
6773 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6774 torch::Tensor output = torch::selu(input);
6775 ForEachDevice([&](const torch::Device& device) {
6776 torch::Tensor lazy_input = CopyToDevice(input, device);
6777 torch::Tensor lazy_output = torch::selu(lazy_input);
6778 AllClose(output, lazy_output);
6779 });
6780}
6781
6782TEST_F(LazyOpsTest, TestSeluInPlace) {
6783 torch::Tensor input = torch::rand(
6784 {2, 1, 4, 6},
6785 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6786 ForEachDevice([&](const torch::Device& device) {
6787 torch::Tensor lazy_input = CopyToDevice(input, device);
6788 torch::Tensor output = torch::selu_(input);
6789 torch::Tensor lazy_output = torch::selu_(lazy_input);
6790 AllClose(output, lazy_output);
6791 AllClose(input, lazy_input);
6792 });
6793}
6794
6795TEST_F(LazyOpsTest, TestCelu) {
6796 torch::Tensor input = torch::rand(
6797 {2, 1, 4, 6},
6798 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6799 torch::Scalar alpha = 2.5;
6800 torch::Tensor output = torch::celu(input, alpha);
6801 ForEachDevice([&](const torch::Device& device) {
6802 torch::Tensor lazy_input = CopyToDevice(input, device);
6803 torch::Tensor lazy_output = torch::celu(lazy_input, alpha);
6804 AllClose(output, lazy_output);
6805 });
6806}
6807
6808TEST_F(LazyOpsTest, TestCeluInPlace) {
6809 torch::Tensor input = torch::rand(
6810 {2, 1, 4, 6},
6811 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6812 torch::Scalar alpha = 2.5;
6813 ForEachDevice([&](const torch::Device& device) {
6814 torch::Tensor lazy_input = CopyToDevice(input, device);
6815 torch::Tensor output = torch::celu_(input, alpha);
6816 torch::Tensor lazy_output = torch::celu_(lazy_input, alpha);
6817 AllClose(output, lazy_output);
6818 AllClose(input, lazy_input);
6819 });
6820}
6821
6822TEST_F(LazyOpsTest, TestGelu) {
6823 torch::Tensor input = torch::rand(
6824 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6825 torch::Tensor output = torch::gelu(input);
6826 ForEachDevice([&](const torch::Device& device) {
6827 torch::Tensor lazy_input = CopyToDevice(input, device);
6828 torch::Tensor lazy_output = torch::gelu(lazy_input);
6829 AllClose(output, lazy_output);
6830 });
6831}
6832
6833TEST_F(LazyOpsTest, TestAddMatMul) {
6834 int in_channels = 32;
6835 int out_channels = 320;
6836 int labels = 50;
6837 torch::Tensor input = torch::rand(
6838 {in_channels, out_channels},
6839 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6840 torch::Tensor weight = torch::rand(
6841 {out_channels, labels},
6842 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6843 torch::Tensor bias = torch::rand(
6844 {labels}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6845 // Test beta != 1. through the CPU interop.
6846 for (double beta : {1., 2.}) {
6847 torch::Tensor output = torch::addmm(bias, input, weight, /*beta=*/beta);
6848 ForEachDevice([&](const torch::Device& device) {
6849 torch::Tensor lazy_input = CopyToDevice(input, device);
6850 torch::Tensor lazy_weight = CopyToDevice(weight, device);
6851 torch::Tensor lazy_bias = CopyToDevice(bias, device);
6852 torch::Tensor lazy_output =
6853 torch::addmm(lazy_bias, lazy_input, lazy_weight, /*beta=*/beta);
6854 AllClose(output, lazy_output);
6855 });
6856 }
6857}
6858
6859TEST_F(LazyOpsTest, TestEmbedding) {
6860 torch::Tensor a = torch::rand(
6861 {32, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6862 torch::Tensor i = torch::randint(
6863 0,
6864 31,
6865 {3, 4},
6866 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
6867 torch::Tensor b = torch::embedding(
6868 a,
6869 i,
6870 /*padding_idx=*/0,
6871 /*scale_grad_by_freq=*/false,
6872 /*sparse=*/false);
6873 ForEachDevice([&](const torch::Device& device) {
6874 torch::Tensor lazy_a = CopyToDevice(a, device);
6875 torch::Tensor lazy_i = CopyToDevice(i, device);
6876 torch::Tensor lazy_b = torch::embedding(
6877 lazy_a,
6878 lazy_i,
6879 /*padding_idx=*/0,
6880 /*scale_grad_by_freq=*/false,
6881 /*sparse=*/false);
6882 AllClose(b, lazy_b);
6883 });
6884}
6885
6886TEST_F(LazyOpsTest, TestOneHot) {
6887 int num_classes = 5;
6888 torch::Tensor input = torch::randint(
6889 0,
6890 num_classes,
6891 {10},
6892 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
6893 torch::Tensor output = torch::one_hot(input, num_classes);
6894 ForEachDevice([&](const torch::Device& device) {
6895 torch::Tensor lazy_input = CopyToDevice(input, device);
6896 torch::Tensor lazy_output = torch::one_hot(lazy_input, num_classes);
6897 AllEqual(output, lazy_output);
6898 });
6899}
6900
6901TEST_F(LazyOpsTest, TestTranspose) {
6902 torch::Tensor input = torch::rand(
6903 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6904 torch::Tensor output = torch::t(input);
6905 ForEachDevice([&](const torch::Device& device) {
6906 torch::Tensor lazy_input = CopyToDevice(input, device);
6907 torch::Tensor lazy_output = torch::t(lazy_input);
6908 AllClose(output, lazy_output);
6909 });
6910}
6911
6912TEST_F(LazyOpsTest, TestTransposeInPlace) {
6913 torch::Tensor input = torch::rand(
6914 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6915 ForEachDevice([&](const torch::Device& device) {
6916 torch::Tensor lazy_input = CopyToDevice(input, device);
6917 torch::Tensor output = input.t_();
6918 torch::Tensor lazy_output = lazy_input.t_();
6919 EXPECT_EQ(lazy_output.sizes(), output.sizes());
6920 AllClose(output, lazy_output);
6921 AllClose(input, lazy_input);
6922 });
6923}
6924
6925TEST_F(LazyOpsTest, TestReshape) {
6926 torch::Tensor input = torch::rand(
6927 {32, 20, 4, 4},
6928 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6929 torch::Tensor output = torch::reshape(input, {-1, 320});
6930 ForEachDevice([&](const torch::Device& device) {
6931 torch::Tensor lazy_input = CopyToDevice(input, device);
6932 torch::Tensor lazy_output = torch::reshape(lazy_input, {-1, 320});
6933 AllClose(output, lazy_output);
6934 });
6935}
6936
6937TEST_F(LazyOpsTest, TestResize) {
6938 // Testing a resize_() with target size bigger than original size is not
6939 // possible, as we fill with zeros, while pytorch fills with random garbage.
6940 torch::Tensor input = torch::rand(
6941 {2, 2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6942 torch::Tensor saved_input = input.clone();
6943 input.resize_({3, 3});
6944 ForEachDevice([&](const torch::Device& device) {
6945 torch::Tensor lazy_input = CopyToDevice(saved_input, device);
6946 lazy_input.resize_({3, 3});
6947 AllClose(input, lazy_input);
6948 });
6949}
6950
6951TEST_F(LazyOpsTest, TestViewResize) {
6952 torch::Tensor input = torch::zeros(
6953 {8, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6954 torch::Tensor saved_input = input.clone();
6955 torch::Tensor output = input.view({4, 4});
6956 output.resize_({3, 3});
6957 ForEachDevice([&](const torch::Device& device) {
6958 torch::Tensor lazy_input = CopyToDevice(saved_input, device);
6959 torch::Tensor lazy_output = lazy_input.view({4, 4});
6960 lazy_output.resize_({3, 3});
6961 AllClose(input, lazy_input);
6962 AllClose(output, lazy_output);
6963 });
6964}
6965
6966TEST_F(LazyOpsTest, TestView) {
6967 torch::Tensor input = torch::rand(
6968 {32, 20, 4, 4},
6969 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6970 torch::Tensor output = input.view({-1, 320});
6971 ForEachDevice([&](const torch::Device& device) {
6972 torch::Tensor lazy_input = CopyToDevice(input, device);
6973 torch::Tensor lazy_output = lazy_input.view({-1, 320});
6974 AllClose(output, lazy_output);
6975 });
6976}
6977
6978TEST_F(LazyOpsTest, TestViewMod) {
6979 torch::Tensor input = torch::zeros(
6980 {32, 20, 4, 4},
6981 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6982 torch::Tensor one = torch::tensor(
6983 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6984 torch::Tensor output = input.view({-1, 320});
6985 output.add_(one, 1.0);
6986 input.add_(one, 1.0);
6987 ForEachDevice([&](const torch::Device& device) {
6988 torch::Tensor xinput = torch::zeros(
6989 {32, 20, 4, 4},
6990 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
6991 torch::Tensor lazy_input = CopyToDevice(xinput, device);
6992 torch::Tensor lazy_one = CopyToDevice(one, device);
6993 torch::Tensor lazy_output = lazy_input.view({-1, 320});
6994 lazy_output.add_(lazy_one, 1.0);
6995 lazy_input.add_(lazy_one, 1.0);
6996 AllClose(output, lazy_output);
6997 AllClose(input, lazy_input);
6998 });
6999}
7000
7001TEST_F(LazyOpsTest, TestViewModComplex) {
7002 torch::Tensor input = torch::zeros(
7003 {32, 20, 4, 4},
7004 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7005 torch::Tensor one = torch::tensor(
7006 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7007 torch::Tensor output1 = input.view({-1, 320});
7008 output1.add_(one, 1.0);
7009 torch::Tensor output2 = input.view({-1, 160});
7010 output2.add_(one, 1.0);
7011 ForEachDevice([&](const torch::Device& device) {
7012 torch::Tensor xinput = torch::zeros(
7013 {32, 20, 4, 4},
7014 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7015 torch::Tensor lazy_input = CopyToDevice(xinput, device);
7016 torch::Tensor lazy_one = CopyToDevice(one, device);
7017 torch::Tensor lazy_output1 = lazy_input.view({-1, 320});
7018 lazy_output1.add_(lazy_one, 1.0);
7019 torch::Tensor lazy_output2 = lazy_input.view({-1, 160});
7020 lazy_output2.add_(lazy_one, 1.0);
7021 AllClose(output1, lazy_output1);
7022 AllClose(output2, lazy_output2);
7023 });
7024}
7025
7026TEST_F(LazyOpsTest, TestViewOfViewMod) {
7027 torch::Tensor input = torch::zeros(
7028 {32, 20, 4, 4},
7029 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7030 torch::Tensor one = torch::tensor(
7031 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7032 torch::Tensor output1 = input.view({-1, 320});
7033 output1.add_(one, 1.0);
7034 torch::Tensor output2 = output1.view({-1, 160});
7035 output2.add_(one, 1.0);
7036 ForEachDevice([&](const torch::Device& device) {
7037 torch::Tensor xinput = torch::zeros(
7038 {32, 20, 4, 4},
7039 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7040 torch::Tensor lazy_input = CopyToDevice(xinput, device);
7041 torch::Tensor lazy_one = CopyToDevice(one, device);
7042 torch::Tensor lazy_output1 = lazy_input.view({-1, 320});
7043 lazy_output1.add_(lazy_one, 1.0);
7044 torch::Tensor lazy_output2 = lazy_output1.view({-1, 160});
7045 lazy_output2.add_(lazy_one, 1.0);
7046 AllClose(output1, lazy_output1);
7047 AllClose(output2, lazy_output2);
7048 });
7049}
7050
7051TEST_F(LazyOpsTest, TestViewSqueezeAddInPlace) {
7052 torch::Tensor input = torch::zeros(
7053 {2, 3, 1}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7054 std::vector<int64_t> view_size = {2, 3, 1, 1};
7055 int squeeze_dim = 2;
7056 torch::Tensor one = torch::tensor(
7057 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7058 ForEachDevice([&](const torch::Device& device) {
7059 torch::Tensor lazy_input = CopyToDevice(input, device);
7060 torch::Tensor output = input.view(view_size);
7061 output.squeeze_(squeeze_dim);
7062 output.add_(one, 1.0);
7063 torch::Tensor lazy_one = CopyToDevice(one, device);
7064 torch::Tensor lazy_output = lazy_input.view(view_size);
7065 lazy_output.squeeze_(squeeze_dim);
7066 lazy_output.add_(lazy_one, 1.0);
7067 AllClose(output, lazy_output);
7068 AllClose(input, lazy_input);
7069 });
7070}
7071
7072TEST_F(LazyOpsTest, TestUnsafeView) {
7073 torch::Tensor input = torch::rand(
7074 {32, 20, 4, 4},
7075 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7076 torch::Tensor output = torch::_unsafe_view(input, {-1, 320});
7077 ForEachDevice([&](const torch::Device& device) {
7078 torch::Tensor lazy_input = CopyToDevice(input, device);
7079 torch::Tensor lazy_output = torch::_unsafe_view(lazy_input, {-1, 320});
7080 AllClose(output, lazy_output);
7081 });
7082}
7083
7084TEST_F(LazyOpsTest, TestNarrow) {
7085 torch::Tensor a = torch::rand(
7086 {8, 10, 4, 4},
7087 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7088 for (int64_t dim : {1, -3}) {
7089 for (int64_t start : {2, -8}) {
7090 torch::Tensor b = a.narrow(dim, start, 6);
7091 ForEachDevice([&](const torch::Device& device) {
7092 torch::Tensor lazy_a = CopyToDevice(a, device);
7093 torch::Tensor lazy_b = lazy_a.narrow(dim, start, 6);
7094 AllClose(b, lazy_b);
7095 });
7096 }
7097 }
7098}
7099
7100TEST_F(LazyOpsTest, TestNarrowUpdate) {
7101 for (int64_t dim : {1, -2}) {
7102 for (int64_t start : {2, -6}) {
7103 torch::Tensor a = torch::rand(
7104 {3, 8, 3},
7105 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7106 torch::Tensor a_copy = a.clone();
7107 torch::Tensor b = torch::rand(
7108 {3, 4, 3},
7109 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7110 torch::Tensor c = a.narrow(dim, start, 4);
7111 c.add_(b, 1.0);
7112 ForEachDevice([&](const torch::Device& device) {
7113 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
7114 torch::Tensor lazy_b = CopyToDevice(b, device);
7115 torch::Tensor lazy_c = lazy_a.narrow(dim, start, 4);
7116 lazy_c.add_(lazy_b, 1.0);
7117 AllClose(c, lazy_c);
7118 });
7119 }
7120 }
7121}
7122
7123TEST_F(LazyOpsTest, TestNarrowUpdateBaseCheck) {
7124 for (int64_t dim : {0, -2}) {
7125 for (int64_t start : {2, -6}) {
7126 torch::Tensor a = torch::zeros(
7127 {8, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7128 torch::Tensor a_copy = a.clone();
7129 torch::Tensor b = torch::ones(
7130 {4, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7131 torch::Tensor c = a.narrow(dim, start, 4);
7132 c.add_(b, 1.0);
7133 ForEachDevice([&](const torch::Device& device) {
7134 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
7135 torch::Tensor lazy_b = CopyToDevice(b, device);
7136 torch::Tensor lazy_c = lazy_a.narrow(dim, start, 4);
7137 lazy_c.add_(lazy_b, 1.0);
7138 AllClose(a, lazy_a);
7139 });
7140 }
7141 }
7142}
7143
7144TEST_F(LazyOpsTest, TestNarrowUpdateTwoSlices) {
7145 for (int64_t dim : {0, -2}) {
7146 for (int64_t start0 : {2, -6}) {
7147 for (int64_t start1 : {6, -2}) {
7148 torch::Tensor a = torch::zeros(
7149 {8, 3},
7150 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7151 torch::Tensor a_copy = a.clone();
7152 torch::Tensor b = torch::ones(
7153 {2, 3},
7154 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7155 torch::Tensor c = b + 1;
7156 torch::Tensor d = a.narrow(dim, start0, 2);
7157 torch::Tensor e = a.narrow(dim, start1, 2);
7158 d.add_(b, 1.0);
7159 e.add_(c, 1.0);
7160 ForEachDevice([&](const torch::Device& device) {
7161 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
7162 torch::Tensor lazy_b = CopyToDevice(b, device);
7163 torch::Tensor lazy_c = CopyToDevice(c, device);
7164 torch::Tensor lazy_d = lazy_a.narrow(dim, start0, 2);
7165 torch::Tensor lazy_e = lazy_a.narrow(dim, start1, 2);
7166 lazy_d.add_(lazy_b, 1.0);
7167 lazy_e.add_(lazy_c, 1.0);
7168 AllClose(d, lazy_d);
7169 AllClose(e, lazy_e);
7170 AllClose(a, lazy_a);
7171 });
7172 }
7173 }
7174 }
7175}
7176
7177TEST_F(LazyOpsTest, TestNarrowUpdateView) {
7178 for (int64_t dim : {0, -3}) {
7179 for (int64_t start : {2, -6}) {
7180 torch::Tensor a = torch::rand(
7181 {8, 2, 3},
7182 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7183 torch::Tensor a_copy = a.clone();
7184 torch::Tensor b = torch::rand(
7185 {4, 6}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7186 torch::Tensor c = a.narrow(dim, start, 4);
7187 torch::Tensor d = c.view({4, 6});
7188 d.add_(b, 1.0);
7189 ForEachDevice([&](const torch::Device& device) {
7190 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
7191 torch::Tensor lazy_b = CopyToDevice(b, device);
7192 torch::Tensor lazy_c = lazy_a.narrow(dim, start, 4);
7193 torch::Tensor lazy_d = lazy_c.view({4, 6});
7194 lazy_d.add_(lazy_b, 1.0);
7195 AllClose(d, lazy_d);
7196 });
7197 }
7198 }
7199}
7200
7201TEST_F(LazyOpsTest, TestNarrowInNarrowUpdate) {
7202 for (int64_t dim : {1, -2}) {
7203 for (int64_t start0 : {1, -7}) {
7204 for (int64_t start1 : {1, -5}) {
7205 torch::Tensor a = torch::rand(
7206 {3, 8, 3},
7207 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7208 torch::Tensor a_copy = a.clone();
7209 torch::Tensor b = torch::rand(
7210 {3, 2, 3},
7211 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7212 torch::Tensor c = a.narrow(dim, start0, 6);
7213 torch::Tensor d = c.narrow(dim, start1, 2);
7214 d.add_(b, 1.0);
7215 ForEachDevice([&](const torch::Device& device) {
7216 torch::Tensor lazy_a = CopyToDevice(a_copy, device);
7217 torch::Tensor lazy_b = CopyToDevice(b, device);
7218 torch::Tensor lazy_c = lazy_a.narrow(dim, start0, 6);
7219 torch::Tensor lazy_d = lazy_c.narrow(dim, start1, 2);
7220 lazy_d.add_(lazy_b, 1.0);
7221 AllClose(a, lazy_a);
7222 });
7223 }
7224 }
7225 }
7226}
7227
7228TEST_F(LazyOpsTest, TestNarrowCopy) {
7229 for (int64_t dim : {1, -3}) {
7230 for (int64_t start : {2, -8}) {
7231 ForEachDevice([&](const torch::Device& device) {
7232 torch::Tensor input = torch::rand(
7233 {8, 10, 4, 4},
7234 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7235 torch::Tensor lazy_input = CopyToDevice(input, device);
7236 torch::Tensor result = input.narrow_copy(dim, start, 6);
7237 input.add_(1);
7238 torch::Tensor lazy_result = lazy_input.narrow_copy(dim, start, 6);
7239 lazy_input.add_(1);
7240 AllClose(result, lazy_result);
7241 });
7242 }
7243 }
7244}
7245
7246TEST_F(LazyOpsTest, TestViewAs) {
7247 torch::Tensor input = torch::rand(
7248 {32, 20, 4, 4},
7249 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7250 torch::Tensor empty = torch::empty({32, 320});
7251 torch::Tensor output = input.view_as(empty);
7252 ForEachDevice([&](const torch::Device& device) {
7253 torch::Tensor lazy_input = CopyToDevice(input, device);
7254 torch::Tensor lazy_empty = CopyToDevice(empty, device);
7255 torch::Tensor lazy_output = lazy_input.view_as(lazy_empty);
7256 AllClose(output, lazy_output);
7257 });
7258}
7259
7260TEST_F(LazyOpsTest, TestLogSoftmax) {
7261 torch::Tensor input = torch::rand(
7262 {5, 3, 4, 2},
7263 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7264 ForEachDevice([&](const torch::Device& device) {
7265 torch::Tensor lazy_input = CopyToDevice(input, device);
7266 int rank = input.dim();
7267 for (int dim = -rank; dim < rank; ++dim) {
7268 torch::Tensor output = torch::log_softmax(input, dim);
7269 torch::Tensor lazy_output = torch::log_softmax(lazy_input, dim);
7270 AllClose(output, lazy_output, /*rtol=*/1e-3);
7271 }
7272 });
7273}
7274
7275TEST_F(LazyOpsTest, TestLogSoftmaxCast) {
7276 torch::Tensor input = torch::rand(
7277 {5, 3, 4, 2},
7278 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7279 ForEachDevice([&](const torch::Device& device) {
7280 torch::Tensor lazy_input = CopyToDevice(input, device);
7281 int rank = input.dim();
7282 for (int dim = -rank; dim < rank; ++dim) {
7283 torch::Tensor output = torch::log_softmax(input, dim, torch::kDouble);
7284 torch::Tensor lazy_output =
7285 torch::log_softmax(lazy_input, dim, torch::kDouble);
7286 AllClose(output, lazy_output, /*rtol=*/1e-3);
7287 }
7288 });
7289}
7290
7291TEST_F(LazyOpsTest, TestLogSoftmaxWrapper) {
7292 torch::Tensor input = torch::rand(
7293 {10, 2, 6, 4},
7294 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7295 ForEachDevice([&](const torch::Device& device) {
7296 torch::Tensor lazy_input = CopyToDevice(input, device);
7297 int rank = input.dim();
7298 for (int dim = -rank; dim < rank; ++dim) {
7299 torch::Tensor output =
7300 torch::_log_softmax(input, dim, /*half_to_float=*/false);
7301 torch::Tensor lazy_output =
7302 torch::_log_softmax(lazy_input, dim, /*half_to_float=*/false);
7303 AllClose(output, lazy_output, /*rtol=*/1e-3);
7304 }
7305 });
7306}
7307
7308TEST_F(LazyOpsTest, TestSoftmax) {
7309 torch::Tensor input = torch::rand(
7310 {10, 2, 6, 4},
7311 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7312 ForEachDevice([&](const torch::Device& device) {
7313 torch::Tensor lazy_input = CopyToDevice(input, device);
7314 int rank = input.dim();
7315 for (int dim = -rank; dim < rank; ++dim) {
7316 torch::Tensor output = torch::softmax(input, dim);
7317 torch::Tensor lazy_output = torch::softmax(lazy_input, dim);
7318 AllClose(output, lazy_output, /*rtol=*/1e-3);
7319 }
7320 });
7321}
7322
7323TEST_F(LazyOpsTest, TestSoftmaxCast) {
7324 torch::Tensor input = torch::rand(
7325 {10, 2, 6, 4},
7326 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7327 ForEachDevice([&](const torch::Device& device) {
7328 torch::Tensor lazy_input = CopyToDevice(input, device);
7329 int rank = input.dim();
7330 for (int dim = -rank; dim < rank; ++dim) {
7331 torch::Tensor output = torch::softmax(input, dim, torch::kDouble);
7332 torch::Tensor lazy_output =
7333 torch::softmax(lazy_input, dim, torch::kDouble);
7334 AllClose(output, lazy_output, /*rtol=*/1e-3);
7335 }
7336 });
7337}
7338
7339TEST_F(LazyOpsTest, TestSoftmaxWrapper) {
7340 torch::Tensor input = torch::rand(
7341 {10, 2, 6, 4},
7342 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7343 ForEachDevice([&](const torch::Device& device) {
7344 torch::Tensor lazy_input = CopyToDevice(input, device);
7345 int rank = input.dim();
7346 for (int dim = -rank; dim < rank; ++dim) {
7347 torch::Tensor output =
7348 torch::_softmax(input, dim, /*half_to_float=*/false);
7349 torch::Tensor lazy_output =
7350 torch::_softmax(lazy_input, dim, /*half_to_float=*/false);
7351 AllClose(output, lazy_output, /*rtol=*/1e-3);
7352 }
7353 });
7354}
7355
7356TEST_F(LazyOpsTest, TestSoftplus) {
7357 torch::Tensor input = torch::rand(
7358 {2, 1, 4, 6},
7359 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7360 torch::Tensor output = torch::softplus(input);
7361 ForEachDevice([&](const torch::Device& device) {
7362 torch::Tensor lazy_input = CopyToDevice(input, device);
7363 torch::Tensor lazy_output = torch::softplus(lazy_input);
7364 AllClose(output, lazy_output, /*rtol=*/1e-4);
7365 });
7366}
7367
7368TEST_F(LazyOpsTest, TestMaxPool1D) {
7369 torch::Tensor input = torch::rand(
7370 {1, 16, 56}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7371 int kernel_size = 3;
7372 for (int stride = 1; stride <= 2; ++stride) {
7373 for (int padding = 0; padding <= 1; ++padding) {
7374 // Test ceil_mode=true through the CPU interop.
7375 for (bool ceil_mode : {false, true}) {
7376 // Test dilation through the CPU interop.
7377 for (int dilation = 1; dilation <= 2; ++dilation) {
7378 torch::Tensor output = torch::max_pool1d(
7379 input,
7380 /*kernel_size=*/{kernel_size},
7381 /*stride=*/{stride},
7382 /*padding=*/{padding},
7383 /*dilation=*/{dilation},
7384 /*ceil_mode=*/ceil_mode);
7385 ForEachDevice([&](const torch::Device& device) {
7386 torch::Tensor lazy_input = CopyToDevice(input, device);
7387 torch::Tensor lazy_output = torch::max_pool1d(
7388 lazy_input,
7389 /*kernel_size=*/{kernel_size},
7390 /*stride=*/{stride},
7391 /*padding=*/{padding},
7392 /*dilation=*/{dilation},
7393 /*ceil_mode=*/ceil_mode);
7394 AllClose(output, lazy_output);
7395 });
7396 }
7397 }
7398 }
7399 }
7400}
7401
7402TEST_F(LazyOpsTest, TestMaxPool2D) {
7403 torch::Tensor input = torch::rand(
7404 {1, 4, 14, 14},
7405 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7406 int kernel_size = 3;
7407 for (int stride = 1; stride <= 2; ++stride) {
7408 for (int padding = 0; padding <= 1; ++padding) {
7409 // Test ceil_mode=true through the CPU interop.
7410 for (bool ceil_mode : {false, true}) {
7411 // Test dilation through the CPU interop.
7412 for (int dilation = 1; dilation <= 2; ++dilation) {
7413 torch::Tensor output = torch::max_pool2d(
7414 input,
7415 /*kernel_size=*/{kernel_size, kernel_size},
7416 /*stride=*/{stride, stride},
7417 /*padding=*/{padding, padding},
7418 /*dilation=*/{dilation, dilation},
7419 /*ceil_mode=*/ceil_mode);
7420 ForEachDevice([&](const torch::Device& device) {
7421 torch::Tensor lazy_input = CopyToDevice(input, device);
7422 torch::Tensor lazy_output = torch::max_pool2d(
7423 lazy_input,
7424 /*kernel_size=*/{kernel_size, kernel_size},
7425 /*stride=*/{stride, stride},
7426 /*padding=*/{padding, padding},
7427 /*dilation=*/{dilation, dilation},
7428 /*ceil_mode=*/ceil_mode);
7429 AllClose(output, lazy_output);
7430 });
7431 }
7432 }
7433 }
7434 }
7435}
7436
7437TEST_F(LazyOpsTest, TestMaxPool2DWithIndices) {
7438 torch::Tensor input = torch::rand(
7439 {1, 4, 14, 14},
7440 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7441 int kernel_size = 3;
7442 for (int stride = 1; stride <= 2; ++stride) {
7443 for (int padding = 0; padding <= 1; ++padding) {
7444 // Test ceil_mode=true through the CPU interop.
7445 for (bool ceil_mode : {false, true}) {
7446 // Test dilation through the CPU interop.
7447 for (int dilation = 1; dilation <= 2; ++dilation) {
7448 auto outputs = torch::max_pool2d_with_indices(
7449 input,
7450 /*kernel_size=*/{kernel_size, kernel_size},
7451 /*stride=*/{stride, stride},
7452 /*padding=*/{padding, padding},
7453 /*dilation=*/{dilation, dilation},
7454 /*ceil_mode=*/ceil_mode);
7455 ForEachDevice([&](const torch::Device& device) {
7456 torch::Tensor lazy_input = CopyToDevice(input, device);
7457 auto lazy_outputs = torch::max_pool2d_with_indices(
7458 lazy_input,
7459 /*kernel_size=*/{kernel_size, kernel_size},
7460 /*stride=*/{stride, stride},
7461 /*padding=*/{padding, padding},
7462 /*dilation=*/{dilation, dilation},
7463 /*ceil_mode=*/ceil_mode);
7464 AllClose(std::get<0>(outputs), std::get<0>(lazy_outputs));
7465 AllClose(std::get<1>(outputs), std::get<1>(lazy_outputs));
7466 });
7467 }
7468 }
7469 }
7470 }
7471}
7472
7473TEST_F(LazyOpsTest, TestMaxPool2DNonSquare) {
7474 torch::Tensor input = torch::rand(
7475 {1, 4, 14, 14},
7476 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7477 int kernel_size = 4;
7478 for (int stride = 1; stride <= 2; ++stride) {
7479 for (int padding = 0; padding <= 1; ++padding) {
7480 // Test ceil_mode=true through the CPU interop.
7481 for (bool ceil_mode : {false, true}) {
7482 // Test dilation through the CPU interop.
7483 for (int dilation = 1; dilation <= 2; ++dilation) {
7484 torch::Tensor output = torch::max_pool2d(
7485 input,
7486 /*kernel_size=*/{kernel_size, kernel_size + 1},
7487 /*stride=*/{stride, stride + 1},
7488 /*padding=*/{padding, padding + 1},
7489 /*dilation=*/{dilation, dilation},
7490 /*ceil_mode=*/ceil_mode);
7491 ForEachDevice([&](const torch::Device& device) {
7492 torch::Tensor lazy_input = CopyToDevice(input, device);
7493 torch::Tensor lazy_output = torch::max_pool2d(
7494 lazy_input,
7495 /*kernel_size=*/{kernel_size, kernel_size + 1},
7496 /*stride=*/{stride, stride + 1},
7497 /*padding=*/{padding, padding + 1},
7498 /*dilation=*/{dilation, dilation},
7499 /*ceil_mode=*/ceil_mode);
7500 AllClose(output, lazy_output);
7501 });
7502 }
7503 }
7504 }
7505 }
7506}
7507
7508TEST_F(LazyOpsTest, TestMaxPool3D) {
7509 torch::Tensor input = torch::rand(
7510 {1, 1, 8, 8, 8},
7511 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7512 int kernel_size = 3;
7513 for (int stride = 1; stride <= 2; ++stride) {
7514 for (int padding = 0; padding <= 1; ++padding) {
7515 // Test ceil_mode=true through the CPU interop.
7516 for (bool ceil_mode : {false, true}) {
7517 // Test dilation through the CPU interop.
7518 for (int dilation = 1; dilation <= 2; ++dilation) {
7519 torch::Tensor output = torch::max_pool3d(
7520 input,
7521 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7522 /*stride=*/{stride, stride, stride},
7523 /*padding=*/{padding, padding, padding},
7524 /*dilation=*/{dilation, dilation, dilation},
7525 /*ceil_mode=*/ceil_mode);
7526 ForEachDevice([&](const torch::Device& device) {
7527 torch::Tensor lazy_input = CopyToDevice(input, device);
7528 torch::Tensor lazy_output = torch::max_pool3d(
7529 lazy_input,
7530 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7531 /*stride=*/{stride, stride, stride},
7532 /*padding=*/{padding, padding, padding},
7533 /*dilation=*/{dilation, dilation, dilation},
7534 /*ceil_mode=*/ceil_mode);
7535 AllClose(output, lazy_output);
7536 });
7537 }
7538 }
7539 }
7540 }
7541}
7542
7543TEST_F(LazyOpsTest, TestMaxPool3DWithIndices) {
7544 torch::Tensor input = torch::rand(
7545 {1, 1, 8, 8, 8},
7546 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7547 int kernel_size = 3;
7548 for (int stride = 1; stride <= 2; ++stride) {
7549 for (int padding = 0; padding <= 1; ++padding) {
7550 // Test ceil_mode=true through the CPU interop.
7551 for (bool ceil_mode : {false, true}) {
7552 // Test dilation through the CPU interop.
7553 for (int dilation = 1; dilation <= 2; ++dilation) {
7554 auto outputs = torch::max_pool3d_with_indices(
7555 input,
7556 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7557 /*stride=*/{stride, stride, stride},
7558 /*padding=*/{padding, padding, padding},
7559 /*dilation=*/{dilation, dilation, dilation},
7560 /*ceil_mode=*/ceil_mode);
7561 ForEachDevice([&](const torch::Device& device) {
7562 torch::Tensor lazy_input = CopyToDevice(input, device);
7563 auto lazy_outputs = torch::max_pool3d_with_indices(
7564 lazy_input,
7565 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7566 /*stride=*/{stride, stride, stride},
7567 /*padding=*/{padding, padding, padding},
7568 /*dilation=*/{dilation, dilation, dilation},
7569 /*ceil_mode=*/ceil_mode);
7570
7571 AllClose(std::get<0>(outputs), std::get<0>(lazy_outputs));
7572 AllClose(std::get<1>(outputs), std::get<1>(lazy_outputs));
7573 });
7574 }
7575 }
7576 }
7577 }
7578}
7579
7580TEST_F(LazyOpsTest, TestMaxPool3DIncompleteAttributes) {
7581 torch::Tensor input = torch::rand(
7582 {1, 1, 8, 8, 8},
7583 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7584 int kernel_size = 3;
7585 for (int stride = 1; stride <= 2; ++stride) {
7586 for (int padding = 0; padding <= 1; ++padding) {
7587 // Test ceil_mode=true through the CPU interop.
7588 for (bool ceil_mode : {false, true}) {
7589 // Test dilation through the CPU interop.
7590 for (int dilation = 1; dilation <= 2; ++dilation) {
7591 torch::Tensor output = torch::max_pool3d(
7592 input,
7593 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7594 /*stride=*/{},
7595 /*padding=*/{padding},
7596 /*dilation=*/{dilation, dilation, dilation},
7597 /*ceil_mode=*/ceil_mode);
7598 ForEachDevice([&](const torch::Device& device) {
7599 torch::Tensor lazy_input = CopyToDevice(input, device);
7600 torch::Tensor lazy_output = torch::max_pool3d(
7601 lazy_input,
7602 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7603 /*stride=*/{},
7604 /*padding=*/{padding},
7605 /*dilation=*/{dilation, dilation, dilation},
7606 /*ceil_mode=*/ceil_mode);
7607 AllClose(output, lazy_output);
7608 });
7609 }
7610 }
7611 }
7612 }
7613}
7614
7615TEST_F(LazyOpsTest, TestMaxPool3DNonSquare) {
7616 torch::Tensor input = torch::rand(
7617 {1, 1, 8, 8, 8},
7618 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7619 int kernel_size = 4;
7620 for (int stride = 1; stride <= 2; ++stride) {
7621 for (int padding = 0; padding <= 1; ++padding) {
7622 // Test ceil_mode=true through the CPU interop.
7623 for (bool ceil_mode : {false, true}) {
7624 // Test dilation through the CPU interop.
7625 for (int dilation = 1; dilation <= 2; ++dilation) {
7626 torch::Tensor output = torch::max_pool3d(
7627 input,
7628 /*kernel_size=*/{kernel_size, kernel_size + 1, kernel_size},
7629 /*stride=*/{stride, stride + 1, stride},
7630 /*padding=*/{padding, padding + 1, padding},
7631 /*dilation=*/{dilation, dilation, dilation},
7632 /*ceil_mode=*/ceil_mode);
7633 ForEachDevice([&](const torch::Device& device) {
7634 torch::Tensor lazy_input = CopyToDevice(input, device);
7635 torch::Tensor lazy_output = torch::max_pool3d(
7636 lazy_input,
7637 /*kernel_size=*/{kernel_size, kernel_size + 1, kernel_size},
7638 /*stride=*/{stride, stride + 1, stride},
7639 /*padding=*/{padding, padding + 1, padding},
7640 /*dilation=*/{dilation, dilation, dilation},
7641 /*ceil_mode=*/ceil_mode);
7642 AllClose(output, lazy_output);
7643 });
7644 }
7645 }
7646 }
7647 }
7648}
7649
7650TEST_F(LazyOpsTest, TestMaxPool2DNoBatch) {
7651 torch::Tensor input = torch::rand(
7652 {4, 14, 14}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7653 int kernel_size = 3;
7654 for (int stride = 1; stride <= 2; ++stride) {
7655 for (int padding = 0; padding <= 1; ++padding) {
7656 // Test ceil_mode=true through the CPU interop.
7657 for (bool ceil_mode : {false, true}) {
7658 // Test dilation through the CPU interop.
7659 for (int dilation = 1; dilation <= 2; ++dilation) {
7660 torch::Tensor output = torch::max_pool2d(
7661 input,
7662 /*kernel_size=*/{kernel_size, kernel_size},
7663 /*stride=*/{stride, stride},
7664 /*padding=*/{padding, padding},
7665 /*dilation=*/{dilation, dilation},
7666 /*ceil_mode=*/ceil_mode);
7667 ForEachDevice([&](const torch::Device& device) {
7668 torch::Tensor lazy_input = CopyToDevice(input, device);
7669 torch::Tensor lazy_output = torch::max_pool2d(
7670 lazy_input,
7671 /*kernel_size=*/{kernel_size, kernel_size},
7672 /*stride=*/{stride, stride},
7673 /*padding=*/{padding, padding},
7674 /*dilation=*/{dilation, dilation},
7675 /*ceil_mode=*/ceil_mode);
7676 AllClose(output, lazy_output);
7677 });
7678 }
7679 }
7680 }
7681 }
7682}
7683
7684TEST_F(LazyOpsTest, TestMaxPool3DNoBatch) {
7685 torch::Tensor input = torch::rand(
7686 {1, 8, 8, 8},
7687 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7688 int kernel_size = 3;
7689 for (int stride = 1; stride <= 2; ++stride) {
7690 for (int padding = 0; padding <= 1; ++padding) {
7691 // Test ceil_mode=true through the CPU interop.
7692 for (bool ceil_mode : {false, true}) {
7693 // Test dilation through the CPU interop.
7694 for (int dilation = 1; dilation <= 2; ++dilation) {
7695 torch::Tensor output = torch::max_pool3d(
7696 input,
7697 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7698 /*stride=*/{stride, stride, stride},
7699 /*padding=*/{padding, padding, padding},
7700 /*dilation=*/{dilation, dilation, dilation},
7701 /*ceil_mode=*/ceil_mode);
7702 ForEachDevice([&](const torch::Device& device) {
7703 torch::Tensor lazy_input = CopyToDevice(input, device);
7704 torch::Tensor lazy_output = torch::max_pool3d(
7705 lazy_input,
7706 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7707 /*stride=*/{stride, stride, stride},
7708 /*padding=*/{padding, padding, padding},
7709 /*dilation=*/{dilation, dilation, dilation},
7710 /*ceil_mode=*/ceil_mode);
7711 AllClose(output, lazy_output);
7712 });
7713 }
7714 }
7715 }
7716 }
7717}
7718
7719TEST_F(LazyOpsTest, TestAvgPool1D) {
7720 torch::Tensor input = torch::rand(
7721 {4, 1, 28}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7722 int kernel_size = 2;
7723 for (int stride = 1; stride <= 2; ++stride) {
7724 for (int padding = 0; padding <= 1; ++padding) {
7725 for (bool count_include_pad : {true, false}) {
7726 // Test ceil_mode=true through the CPU interop.
7727 for (bool ceil_mode : {false, true}) {
7728 torch::Tensor output = torch::avg_pool1d(
7729 input,
7730 /*kernel_size=*/{kernel_size},
7731 /*stride=*/{stride},
7732 /*padding=*/{padding},
7733 /*ceil_mode=*/ceil_mode,
7734 /*count_include_pad=*/count_include_pad);
7735 ForEachDevice([&](const torch::Device& device) {
7736 torch::Tensor lazy_input = CopyToDevice(input, device);
7737 torch::Tensor lazy_output = torch::avg_pool1d(
7738 lazy_input,
7739 /*kernel_size=*/{kernel_size},
7740 /*stride=*/{stride},
7741 /*padding=*/{padding},
7742 /*ceil_mode=*/ceil_mode,
7743 /*count_include_pad=*/count_include_pad);
7744 AllClose(output, lazy_output);
7745 });
7746 }
7747 }
7748 }
7749 }
7750}
7751
7752TEST_F(LazyOpsTest, TestAvgPool2D) {
7753 torch::Tensor input = torch::rand(
7754 {2, 1, 14, 14},
7755 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7756 int kernel_size = 2;
7757 for (int stride = 1; stride <= 2; ++stride) {
7758 for (int padding = 0; padding <= 1; ++padding) {
7759 for (bool count_include_pad : {true, false}) {
7760 // Test ceil_mode=true through the CPU interop.
7761 for (bool ceil_mode : {false, true}) {
7762 torch::Tensor output = torch::avg_pool2d(
7763 input,
7764 /*kernel_size=*/{kernel_size, kernel_size},
7765 /*stride=*/{stride, stride},
7766 /*padding=*/{padding, padding},
7767 /*ceil_mode=*/ceil_mode,
7768 /*count_include_pad=*/count_include_pad);
7769 ForEachDevice([&](const torch::Device& device) {
7770 // torch::Tensor lazy_input = CopyToDevice(input, device);
7771 torch::Tensor lazy_input = CopyToDevice(input, device);
7772 torch::Tensor lazy_output = torch::avg_pool2d(
7773 lazy_input,
7774 /*kernel_size=*/{kernel_size, kernel_size},
7775 /*stride=*/{stride, stride},
7776 /*padding=*/{padding, padding},
7777 /*ceil_mode=*/ceil_mode,
7778 /*count_include_pad=*/count_include_pad);
7779 AllClose(output, lazy_output.to(torch::kCPU));
7780 });
7781 }
7782 }
7783 }
7784 }
7785}
7786
7787TEST_F(LazyOpsTest, TestAvgPool2DNonSquare) {
7788 torch::Tensor input = torch::rand(
7789 {2, 1, 14, 14},
7790 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7791 int kernel_size = 4;
7792 for (int stride = 1; stride <= 2; ++stride) {
7793 for (int padding = 0; padding <= 1; ++padding) {
7794 for (bool count_include_pad : {true, false}) {
7795 // Test ceil_mode=true through the CPU interop.
7796 for (bool ceil_mode : {false, true}) {
7797 torch::Tensor output = torch::avg_pool2d(
7798 input,
7799 /*kernel_size=*/{kernel_size, kernel_size + 1},
7800 /*stride=*/{stride, stride + 1},
7801 /*padding=*/{padding, padding + 1},
7802 /*ceil_mode=*/ceil_mode,
7803 /*count_include_pad=*/count_include_pad);
7804 ForEachDevice([&](const torch::Device& device) {
7805 torch::Tensor lazy_input = CopyToDevice(input, device);
7806 torch::Tensor lazy_output = torch::avg_pool2d(
7807 lazy_input,
7808 /*kernel_size=*/{kernel_size, kernel_size + 1},
7809 /*stride=*/{stride, stride + 1},
7810 /*padding=*/{padding, padding + 1},
7811 /*ceil_mode=*/ceil_mode,
7812 /*count_include_pad=*/count_include_pad);
7813 AllClose(output, lazy_output);
7814 });
7815 }
7816 }
7817 }
7818 }
7819}
7820
7821TEST_F(LazyOpsTest, TestAvgPool3D) {
7822 torch::Tensor input = torch::rand(
7823 {1, 1, 7, 7, 7},
7824 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7825 int kernel_size = 2;
7826 for (int stride = 1; stride <= 2; ++stride) {
7827 for (int padding = 0; padding <= 1; ++padding) {
7828 for (bool count_include_pad : {true, false}) {
7829 // Test ceil_mode=true through the CPU interop.
7830 for (bool ceil_mode : {false, true}) {
7831 torch::Tensor output = torch::avg_pool3d(
7832 input,
7833 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7834 /*stride=*/{stride, stride, stride},
7835 /*padding=*/{padding, padding, padding},
7836 /*ceil_mode=*/ceil_mode,
7837 /*count_include_pad=*/count_include_pad);
7838 ForEachDevice([&](const torch::Device& device) {
7839 torch::Tensor lazy_input = CopyToDevice(input, device);
7840 torch::Tensor lazy_output = torch::avg_pool3d(
7841 lazy_input,
7842 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7843 /*stride=*/{stride, stride, stride},
7844 /*padding=*/{padding, padding, padding},
7845 /*ceil_mode=*/ceil_mode,
7846 /*count_include_pad=*/count_include_pad);
7847 AllClose(output, lazy_output);
7848 });
7849 }
7850 }
7851 }
7852 }
7853}
7854
7855TEST_F(LazyOpsTest, TestAvgPool3DIncompleteAttributes) {
7856 torch::Tensor input = torch::rand(
7857 {1, 1, 7, 7, 7},
7858 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7859 int kernel_size = 2;
7860 for (int stride = 1; stride <= 2; ++stride) {
7861 for (int padding = 0; padding <= 1; ++padding) {
7862 for (bool count_include_pad : {true, false}) {
7863 // Test ceil_mode=true through the CPU interop.
7864 for (bool ceil_mode : {false, true}) {
7865 torch::Tensor output = torch::avg_pool3d(
7866 input,
7867 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7868 /*stride=*/{},
7869 /*padding=*/{padding, padding, padding},
7870 /*ceil_mode=*/ceil_mode,
7871 /*count_include_pad=*/count_include_pad);
7872 ForEachDevice([&](const torch::Device& device) {
7873 torch::Tensor lazy_input = CopyToDevice(input, device);
7874 torch::Tensor lazy_output = torch::avg_pool3d(
7875 lazy_input,
7876 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7877 /*stride=*/{},
7878 /*padding=*/{padding, padding, padding},
7879 /*ceil_mode=*/ceil_mode,
7880 /*count_include_pad=*/count_include_pad);
7881 AllClose(output, lazy_output);
7882 });
7883 }
7884 }
7885 }
7886 }
7887}
7888
7889TEST_F(LazyOpsTest, TestAvgPool3DNonSquare) {
7890 torch::Tensor input = torch::rand(
7891 {1, 1, 7, 7, 7},
7892 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7893 int kernel_size = 4;
7894 for (int stride = 1; stride <= 2; ++stride) {
7895 for (int padding = 0; padding <= 1; ++padding) {
7896 for (bool count_include_pad : {true, false}) {
7897 // Test ceil_mode=true through the CPU interop.
7898 for (bool ceil_mode : {false, true}) {
7899 torch::Tensor output = torch::avg_pool3d(
7900 input,
7901 /*kernel_size=*/{kernel_size, kernel_size + 1, kernel_size},
7902 /*stride=*/{stride, stride + 1, stride},
7903 /*padding=*/{padding, padding + 1, padding},
7904 /*ceil_mode=*/ceil_mode,
7905 /*count_include_pad=*/count_include_pad);
7906 ForEachDevice([&](const torch::Device& device) {
7907 torch::Tensor lazy_input = CopyToDevice(input, device);
7908 torch::Tensor lazy_output = torch::avg_pool3d(
7909 lazy_input,
7910 /*kernel_size=*/{kernel_size, kernel_size + 1, kernel_size},
7911 /*stride=*/{stride, stride + 1, stride},
7912 /*padding=*/{padding, padding + 1, padding},
7913 /*ceil_mode=*/ceil_mode,
7914 /*count_include_pad=*/count_include_pad);
7915 AllClose(output, lazy_output);
7916 });
7917 }
7918 }
7919 }
7920 }
7921}
7922
7923TEST_F(LazyOpsTest, TestAvgPool2DNoBatch) {
7924 torch::Tensor input = torch::rand(
7925 {1, 7, 7}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7926 int kernel_size = 2;
7927 for (int stride = 1; stride <= 2; ++stride) {
7928 for (int padding = 0; padding <= 1; ++padding) {
7929 for (bool count_include_pad : {true, false}) {
7930 // Test ceil_mode=true through the CPU interop.
7931 for (bool ceil_mode : {false, true}) {
7932 torch::Tensor output = torch::avg_pool2d(
7933 input,
7934 /*kernel_size=*/{kernel_size, kernel_size},
7935 /*stride=*/{stride, stride},
7936 /*padding=*/{padding, padding},
7937 /*ceil_mode=*/ceil_mode,
7938 /*count_include_pad=*/count_include_pad);
7939 ForEachDevice([&](const torch::Device& device) {
7940 torch::Tensor lazy_input = CopyToDevice(input, device);
7941 torch::Tensor lazy_output = torch::avg_pool2d(
7942 lazy_input,
7943 /*kernel_size=*/{kernel_size, kernel_size},
7944 /*stride=*/{stride, stride},
7945 /*padding=*/{padding, padding},
7946 /*ceil_mode=*/ceil_mode,
7947 /*count_include_pad=*/count_include_pad);
7948 AllClose(output, lazy_output);
7949 });
7950 }
7951 }
7952 }
7953 }
7954}
7955
7956TEST_F(LazyOpsTest, TestAvgPool3DNoBatch) {
7957 torch::Tensor input = torch::rand(
7958 {1, 7, 7, 7},
7959 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7960 int kernel_size = 2;
7961 for (int stride = 1; stride <= 2; ++stride) {
7962 for (int padding = 0; padding <= 1; ++padding) {
7963 for (bool count_include_pad : {true, false}) {
7964 // Test ceil_mode=true through the CPU interop.
7965 for (bool ceil_mode : {false, true}) {
7966 torch::Tensor output = torch::avg_pool3d(
7967 input,
7968 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7969 /*stride=*/{stride, stride, stride},
7970 /*padding=*/{padding, padding, padding},
7971 /*ceil_mode=*/ceil_mode,
7972 /*count_include_pad=*/count_include_pad);
7973 ForEachDevice([&](const torch::Device& device) {
7974 torch::Tensor lazy_input = CopyToDevice(input, device);
7975 torch::Tensor lazy_output = torch::avg_pool3d(
7976 lazy_input,
7977 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
7978 /*stride=*/{stride, stride, stride},
7979 /*padding=*/{padding, padding, padding},
7980 /*ceil_mode=*/ceil_mode,
7981 /*count_include_pad=*/count_include_pad);
7982 AllClose(output, lazy_output);
7983 });
7984 }
7985 }
7986 }
7987 }
7988}
7989
7990TEST_F(LazyOpsTest, TestAdaptiveAvgPool2D) {
7991 torch::Tensor input = torch::rand(
7992 {4, 1, 28, 28},
7993 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
7994 for (int64_t output_size : {7, 4}) {
7995 torch::Tensor output =
7996 torch::adaptive_avg_pool2d(input, {output_size, output_size});
7997 ForEachDevice([&](const torch::Device& device) {
7998 torch::Tensor lazy_input = CopyToDevice(input, device);
7999 torch::Tensor lazy_output =
8000 torch::adaptive_avg_pool2d(lazy_input, {output_size, output_size});
8001 AllClose(output, lazy_output);
8002 });
8003 }
8004}
8005
8006TEST_F(LazyOpsTest, TestAdaptiveAvgPool3D) {
8007 torch::Tensor input = torch::rand(
8008 {9, 4, 56, 28, 28},
8009 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8010 for (int64_t output_size : {7, 4}) {
8011 torch::Tensor output = torch::adaptive_avg_pool3d(
8012 input, {output_size, output_size, output_size});
8013 ForEachDevice([&](const torch::Device& device) {
8014 torch::Tensor lazy_input = CopyToDevice(input, device);
8015 torch::Tensor lazy_output = torch::adaptive_avg_pool3d(
8016 lazy_input, {output_size, output_size, output_size});
8017 AllClose(output, lazy_output);
8018 });
8019 }
8020}
8021
8022TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DNoBatch) {
8023 torch::Tensor input = torch::rand(
8024 {3, 56, 28, 28},
8025 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8026 for (int64_t output_size : {7, 4}) {
8027 torch::Tensor output = torch::adaptive_avg_pool3d(
8028 input, {output_size, output_size, output_size});
8029 ForEachDevice([&](const torch::Device& device) {
8030 torch::Tensor lazy_input = CopyToDevice(input, device);
8031 torch::Tensor lazy_output = torch::adaptive_avg_pool3d(
8032 lazy_input, {output_size, output_size, output_size});
8033 AllClose(output, lazy_output);
8034 });
8035 }
8036}
8037
8038TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DNoBatch) {
8039 torch::Tensor input = torch::rand(
8040 {1, 56, 56}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8041 for (int64_t output_size : {7, 8}) {
8042 torch::Tensor output =
8043 torch::adaptive_avg_pool2d(input, {output_size, output_size});
8044 ForEachDevice([&](const torch::Device& device) {
8045 torch::Tensor lazy_input = CopyToDevice(input, device);
8046 torch::Tensor lazy_output =
8047 torch::adaptive_avg_pool2d(lazy_input, {output_size, output_size});
8048 AllClose(output, lazy_output);
8049 });
8050 }
8051}
8052
8053TEST_F(LazyOpsTest, TestMaxUnpool2D) {
8054 int kernel_size = 2;
8055 torch::Tensor input = torch::rand(
8056 {2, 2, 8, 8},
8057 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8058 for (int stride = 1; stride <= 2; ++stride) {
8059 for (int padding = 0; padding <= 1; ++padding) {
8060 // Test ceil_mode=true through the CPU interop.
8061 for (bool ceil_mode : {false, true}) {
8062 // Test dilation through the CPU interop.
8063 for (int dilation = 1; dilation <= 2; ++dilation) {
8064 torch::Tensor output;
8065 torch::Tensor indices;
8066 std::tie(output, indices) = torch::max_pool2d_with_indices(
8067 input,
8068 /*kernel_size=*/{kernel_size, kernel_size},
8069 /*stride=*/{stride, stride},
8070 /*padding=*/{padding, padding},
8071 /*dilation=*/{dilation, dilation},
8072 /*ceil_mode=*/ceil_mode);
8073
8074 std::vector<int64_t> output_size({input.size(2), input.size(3)});
8075 at::Tensor utensor =
8076 torch::max_unpool2d(output, indices, output_size);
8077
8078 ForEachDevice([&](const torch::Device& device) {
8079 torch::Tensor lazy_output = CopyToDevice(output, device);
8080 torch::Tensor lazy_indices = CopyToDevice(indices, device);
8081 at::Tensor lazy_utensor =
8082 torch::max_unpool2d(lazy_output, lazy_indices, output_size);
8083 AllClose(utensor, lazy_utensor);
8084 });
8085 }
8086 }
8087 }
8088 }
8089}
8090
8091TEST_F(LazyOpsTest, TestMaxUnpool3D) {
8092 int kernel_size = 2;
8093 torch::Tensor input = torch::rand(
8094 {1, 1, 4, 4, 4},
8095 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8096 for (int stride = 1; stride <= 2; ++stride) {
8097 for (int padding = 0; padding <= 1; ++padding) {
8098 // Test ceil_mode=true through the CPU interop.
8099 for (bool ceil_mode : {false, true}) {
8100 // Test dilation through the CPU interop.
8101 for (int dilation = 1; dilation <= 2; ++dilation) {
8102 torch::Tensor output;
8103 torch::Tensor indices;
8104 std::tie(output, indices) = torch::max_pool3d_with_indices(
8105 input,
8106 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
8107 /*stride=*/{stride, stride, stride},
8108 /*padding=*/{padding, padding, padding},
8109 /*dilation=*/{dilation, dilation, dilation},
8110 /*ceil_mode=*/ceil_mode);
8111
8112 std::vector<int64_t> output_size(
8113 {input.size(2), input.size(3), input.size(4)});
8114 at::Tensor utensor = torch::max_unpool3d(
8115 output,
8116 indices,
8117 output_size,
8118 /*stride=*/{stride, stride, stride},
8119 /*padding=*/{padding, padding, padding});
8120
8121 ForEachDevice([&](const torch::Device& device) {
8122 torch::Tensor lazy_output = CopyToDevice(output, device);
8123 torch::Tensor lazy_indices = CopyToDevice(indices, device);
8124 at::Tensor lazy_utensor = torch::max_unpool3d(
8125 lazy_output,
8126 lazy_indices,
8127 output_size,
8128 /*stride=*/{stride, stride, stride},
8129 /*padding=*/{padding, padding, padding});
8130 AllClose(utensor, lazy_utensor);
8131 });
8132 }
8133 }
8134 }
8135 }
8136}
8137
8138TEST_F(LazyOpsTest, TestNllLoss) {
8139 // TODO(whc) debug divide-by-zero failure under ASAN
8140 GTEST_SKIP();
8141
8142 int batch = 6;
8143 int classes = 2;
8144 // TODO(asuhan): Fix the torch::kDouble case.
8145 for (auto dtype : {torch::kFloat}) {
8146 for (int ignore_index : {-1, 0, 1, 5}) {
8147 for (bool def_weight : {false, true}) {
8148 torch::Tensor input = torch::rand(
8149 {batch, classes},
8150 torch::TensorOptions(dtype).device(DefaultDevice()));
8151 torch::Tensor target = torch::randint(
8152 std::min(ignore_index, 0),
8153 classes,
8154 {batch},
8155 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
8156 torch::Tensor weight;
8157 if (def_weight) {
8158 weight = torch::rand(
8159 {classes}, torch::TensorOptions(dtype).device(DefaultDevice()));
8160 }
8161 for (torch::Reduction::Reduction reduction :
8162 {torch::Reduction::Mean,
8163 torch::Reduction::Sum,
8164 torch::Reduction::None}) {
8165 torch::Tensor output = torch::nll_loss(
8166 /*self=*/input,
8167 /*target=*/target,
8168 /*weight=*/weight,
8169 /*reduction=*/reduction,
8170 /*ignore_index=*/ignore_index);
8171
8172 ForEachDevice([&](const torch::Device& device) {
8173 torch::Tensor lazy_input = CopyToDevice(input, device);
8174 torch::Tensor lazy_target = CopyToDevice(target, device);
8175 torch::Tensor lazy_weight =
8176 def_weight ? CopyToDevice(weight, device) : torch::Tensor();
8177 torch::Tensor lazy_output = torch::nll_loss(
8178 /*self=*/lazy_input,
8179 /*target=*/lazy_target,
8180 /*weight=*/lazy_weight,
8181 /*reduction=*/reduction,
8182 /*ignore_index=*/ignore_index);
8183 AllClose(output, lazy_output);
8184 });
8185 }
8186 }
8187 }
8188 }
8189}
8190
8191TEST_F(LazyOpsTest, TestNllLoss2d) {
8192 int batch = 6;
8193 int classes = 2;
8194 int height = 3;
8195 int width = 3;
8196 // TODO(asuhan): Fix the torch::kDouble case.
8197 for (auto dtype : {torch::kFloat}) {
8198 for (int ignore_index : {-1, 0, 1, 5}) {
8199 for (bool def_weight : {false, true}) {
8200 torch::Tensor input = torch::rand(
8201 {batch, classes, height, width},
8202 torch::TensorOptions(dtype).device(DefaultDevice()));
8203 torch::Tensor target = torch::randint(
8204 std::min(ignore_index, 0),
8205 classes,
8206 {batch, height, width},
8207 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
8208 torch::Tensor weight;
8209 if (def_weight) {
8210 weight = torch::rand(
8211 {classes}, torch::TensorOptions(dtype).device(DefaultDevice()));
8212 }
8213 for (torch::Reduction::Reduction reduction :
8214 {torch::Reduction::Mean,
8215 torch::Reduction::Sum,
8216 torch::Reduction::None}) {
8217 torch::Tensor output = torch::nll_loss2d(
8218 /*self=*/input,
8219 /*target=*/target,
8220 /*weight=*/weight,
8221 /*reduction=*/reduction,
8222 /*ignore_index=*/ignore_index);
8223
8224 ForEachDevice([&](const torch::Device& device) {
8225 torch::Tensor lazy_input = CopyToDevice(input, device);
8226 torch::Tensor lazy_target = CopyToDevice(target, device);
8227 torch::Tensor lazy_weight =
8228 def_weight ? CopyToDevice(weight, device) : torch::Tensor();
8229 torch::Tensor lazy_output = torch::nll_loss2d(
8230 /*self=*/lazy_input,
8231 /*target=*/lazy_target,
8232 /*weight=*/lazy_weight,
8233 /*reduction=*/reduction,
8234 /*ignore_index=*/ignore_index);
8235 AllClose(output, lazy_output);
8236 });
8237 }
8238 }
8239 }
8240 }
8241}
8242
8243TEST_F(LazyOpsTest, TestSmoothL1Loss) {
8244 torch::Tensor input = torch::randn(
8245 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8246 torch::Tensor target = torch::randn(
8247 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8248 for (torch::Reduction::Reduction reduction :
8249 {torch::Reduction::None,
8250 torch::Reduction::Mean,
8251 torch::Reduction::Sum}) {
8252 for (double beta : {0.25, 1.}) {
8253 torch::Tensor output =
8254 torch::smooth_l1_loss(input, target, reduction, beta);
8255 ForEachDevice([&](const torch::Device& device) {
8256 torch::Tensor lazy_input = CopyToDevice(input, device);
8257 torch::Tensor lazy_target = CopyToDevice(target, device);
8258 torch::Tensor lazy_output =
8259 torch::smooth_l1_loss(lazy_input, lazy_target, reduction, beta);
8260 AllClose(output, lazy_output);
8261 });
8262 }
8263 }
8264}
8265
8266TEST_F(LazyOpsTest, TestL1Loss) {
8267 torch::Tensor input = torch::randn(
8268 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8269 torch::Tensor target = torch::randn(
8270 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8271 for (torch::Reduction::Reduction reduction :
8272 {torch::Reduction::None,
8273 torch::Reduction::Mean,
8274 torch::Reduction::Sum}) {
8275 torch::Tensor output = torch::l1_loss(input, target, reduction);
8276 ForEachDevice([&](const torch::Device& device) {
8277 torch::Tensor lazy_input = CopyToDevice(input, device);
8278 torch::Tensor lazy_target = CopyToDevice(target, device);
8279 torch::Tensor lazy_output =
8280 torch::l1_loss(lazy_input, lazy_target, reduction);
8281 AllClose(output, lazy_output);
8282 });
8283 }
8284}
8285
8286TEST_F(LazyOpsTest, TestL1LossBackward) {
8287 for (torch::Reduction::Reduction reduction :
8288 {torch::Reduction::None,
8289 torch::Reduction::Mean,
8290 torch::Reduction::Sum}) {
8291 auto testfn =
8292 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
8293 return torch::l1_loss(inputs[0], inputs[1], reduction);
8294 };
8295 ForEachDevice([&](const torch::Device& device) {
8296 TestBackward(
8297 {torch::rand(
8298 {2, 4},
8299 torch::TensorOptions(torch::kFloat)
8300 .device(DefaultDevice())
8301 .requires_grad(true)),
8302 torch::rand(
8303 {2, 4},
8304 torch::TensorOptions(torch::kFloat).device(DefaultDevice()))},
8305 device,
8306 testfn);
8307 });
8308 }
8309}
8310
8311TEST_F(LazyOpsTest, TestMseLoss) {
8312 torch::Tensor input = torch::randn(
8313 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8314 torch::Tensor target = torch::randn(
8315 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8316 for (torch::Reduction::Reduction reduction :
8317 {torch::Reduction::None,
8318 torch::Reduction::Mean,
8319 torch::Reduction::Sum}) {
8320 torch::Tensor output = torch::mse_loss(input, target, reduction);
8321 ForEachDevice([&](const torch::Device& device) {
8322 torch::Tensor lazy_input = CopyToDevice(input, device);
8323 torch::Tensor lazy_target = CopyToDevice(target, device);
8324 torch::Tensor lazy_output =
8325 torch::mse_loss(lazy_input, lazy_target, reduction);
8326 AllClose(output, lazy_output);
8327 });
8328 }
8329}
8330
8331TEST_F(LazyOpsTest, TestMseLossBackward) {
8332 for (torch::Reduction::Reduction reduction :
8333 {torch::Reduction::None,
8334 torch::Reduction::Mean,
8335 torch::Reduction::Sum}) {
8336 auto testfn =
8337 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
8338 return torch::mse_loss(inputs[0], inputs[1], reduction);
8339 };
8340 ForEachDevice([&](const torch::Device& device) {
8341 TestBackward(
8342 {torch::rand(
8343 {2, 4},
8344 torch::TensorOptions(torch::kFloat)
8345 .device(DefaultDevice())
8346 .requires_grad(true)),
8347 torch::rand(
8348 {2, 4},
8349 torch::TensorOptions(torch::kFloat).device(DefaultDevice()))},
8350 device,
8351 testfn);
8352 });
8353 }
8354}
8355
8356TEST_F(LazyOpsTest, TestBatchNorm1D) {
8357 int num_features = 3;
8358 torch::Tensor input = torch::rand(
8359 {2, num_features, 4},
8360 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8361 torch::Tensor weight = torch::rand(
8362 {num_features},
8363 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8364 torch::Tensor bias = torch::rand(
8365 {num_features},
8366 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8367 torch::Tensor running_mean = torch::zeros(
8368 {num_features},
8369 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8370 torch::Tensor running_var = torch::ones(
8371 {num_features},
8372 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8373 double momentum = 0.1;
8374 double eps = 0.5;
8375 torch::Tensor undef;
8376 for (bool training : {true, false}) {
8377 for (bool undef_weight_bias : {false, true}) {
8378 torch::Tensor output = torch::batch_norm(
8379 /*input=*/input,
8380 /*weight=*/undef_weight_bias ? undef : weight,
8381 /*bias=*/undef_weight_bias ? undef : bias,
8382 /*running_mean=*/running_mean,
8383 /*running_var=*/running_var,
8384 /*training=*/training,
8385 /*momentum=*/momentum,
8386 /*eps=*/eps,
8387 /*cudnn_enabled=*/false);
8388 ForEachDevice([&](const torch::Device& device) {
8389 torch::Tensor lazy_input = CopyToDevice(input, device);
8390 torch::Tensor lazy_weight =
8391 undef_weight_bias ? undef : CopyToDevice(weight, device);
8392 torch::Tensor lazy_bias =
8393 undef_weight_bias ? undef : CopyToDevice(bias, device);
8394 torch::Tensor lazy_running_mean = CopyToDevice(running_mean, device);
8395 torch::Tensor lazy_running_var = CopyToDevice(running_var, device);
8396 torch::Tensor lazy_output = torch::batch_norm(
8397 /*input=*/lazy_input,
8398 /*weight=*/lazy_weight,
8399 /*bias=*/lazy_bias,
8400 /*running_mean=*/lazy_running_mean,
8401 /*running_var=*/lazy_running_var,
8402 /*training=*/training,
8403 /*momentum=*/momentum,
8404 /*eps=*/eps,
8405 /*cudnn_enabled=*/false);
8406 AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5);
8407 });
8408 }
8409 }
8410}
8411
8412TEST_F(LazyOpsTest, TestBatchNorm2D) {
8413 int num_features = 3;
8414 torch::Tensor input = torch::rand(
8415 {2, num_features, 4, 4},
8416 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8417 torch::Tensor weight = torch::rand(
8418 {num_features},
8419 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8420 torch::Tensor bias = torch::rand(
8421 {num_features},
8422 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8423 torch::Tensor running_mean = torch::zeros(
8424 {num_features},
8425 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8426 torch::Tensor running_var = torch::ones(
8427 {num_features},
8428 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8429 double momentum = 0.1;
8430 double eps = 0.5;
8431 torch::Tensor undef;
8432 for (bool training : {true, false}) {
8433 for (bool undef_weight_bias : {false, true}) {
8434 torch::Tensor output = torch::batch_norm(
8435 /*input=*/input,
8436 /*weight=*/undef_weight_bias ? undef : weight,
8437 /*bias=*/undef_weight_bias ? undef : bias,
8438 /*running_mean=*/running_mean,
8439 /*running_var=*/running_var,
8440 /*training=*/training,
8441 /*momentum=*/momentum,
8442 /*eps=*/eps,
8443 /*cudnn_enabled=*/false);
8444 ForEachDevice([&](const torch::Device& device) {
8445 torch::Tensor lazy_input = CopyToDevice(input, device);
8446 torch::Tensor lazy_weight =
8447 undef_weight_bias ? undef : CopyToDevice(weight, device);
8448 torch::Tensor lazy_bias =
8449 undef_weight_bias ? undef : CopyToDevice(bias, device);
8450 torch::Tensor lazy_running_mean = CopyToDevice(running_mean, device);
8451 torch::Tensor lazy_running_var = CopyToDevice(running_var, device);
8452 torch::Tensor lazy_output = torch::batch_norm(
8453 /*input=*/lazy_input,
8454 /*weight=*/lazy_weight,
8455 /*bias=*/lazy_bias,
8456 /*running_mean=*/lazy_running_mean,
8457 /*running_var=*/lazy_running_var,
8458 /*training=*/training,
8459 /*momentum=*/momentum,
8460 /*eps=*/eps,
8461 /*cudnn_enabled=*/false);
8462 AllClose(output, lazy_output, /*rtol=*/1e-3, /*atol=*/1e-5);
8463 });
8464 }
8465 }
8466}
8467
8468TEST_F(LazyOpsTest, TestDim) {
8469 torch::Tensor input = torch::rand(
8470 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8471 ForEachDevice([&](const torch::Device& device) {
8472 torch::Tensor lazy_input = CopyToDevice(input, device);
8473 EXPECT_EQ(input.dim(), lazy_input.dim());
8474 });
8475}
8476
8477TEST_F(LazyOpsTest, TestContiguous) {
8478 torch::Tensor input = torch::rand(
8479 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8480 torch::Tensor output = torch::native::contiguous(input);
8481 ForEachDevice([&](const torch::Device& device) {
8482 torch::Tensor lazy_input = CopyToDevice(input, device);
8483 torch::Tensor lazy_output = torch::native::contiguous(lazy_input);
8484 AllClose(output, lazy_output);
8485 });
8486}
8487
8488TEST_F(LazyOpsTest, TestSqueezeAll) {
8489 torch::Tensor input = torch::rand(
8490 {2, 1, 3, 1},
8491 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8492 torch::Tensor output = torch::squeeze(input);
8493 ForEachDevice([&](const torch::Device& device) {
8494 torch::Tensor lazy_input = CopyToDevice(input, device);
8495 torch::Tensor lazy_output = torch::squeeze(lazy_input);
8496 AllClose(output, lazy_output);
8497 });
8498}
8499
8500TEST_F(LazyOpsTest, TestSqueezeAllInPlace) {
8501 ForEachDevice([&](const torch::Device& device) {
8502 torch::Tensor input = torch::rand(
8503 {2, 1, 3, 1},
8504 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8505 torch::Tensor lazy_input = CopyToDevice(input, device);
8506 torch::Tensor output = input.squeeze_();
8507 torch::Tensor lazy_output = lazy_input.squeeze_();
8508 AllClose(output, lazy_output);
8509 AllClose(input, lazy_input);
8510 ASSERT_EQ(input.dim(), lazy_input.dim());
8511 for (int64_t dim_idx = 0; dim_idx < input.dim(); ++dim_idx) {
8512 ASSERT_EQ(input.size(dim_idx), lazy_input.size(dim_idx));
8513 }
8514 });
8515}
8516
8517TEST_F(LazyOpsTest, TestSqueezeOne) {
8518 torch::Tensor input = torch::rand(
8519 {2, 1, 3, 1},
8520 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8521 int rank = input.dim();
8522 for (int dim = -rank; dim < rank; ++dim) {
8523 torch::Tensor output = torch::squeeze(input, dim);
8524 ForEachDevice([&](const torch::Device& device) {
8525 torch::Tensor lazy_input = CopyToDevice(input, device);
8526 torch::Tensor lazy_output = torch::squeeze(lazy_input, dim);
8527 AllClose(output, lazy_output);
8528 });
8529 }
8530}
8531
8532TEST_F(LazyOpsTest, TestSqueezeOneInPlace) {
8533 int rank = 4;
8534 for (int dim = -rank; dim < rank; ++dim) {
8535 ForEachDevice([&](const torch::Device& device) {
8536 torch::Tensor input = torch::rand(
8537 {2, 1, 3, 1},
8538 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8539 torch::Tensor lazy_input = CopyToDevice(input, device);
8540 torch::Tensor output = input.squeeze_(dim);
8541 torch::Tensor lazy_output = lazy_input.squeeze_(dim);
8542 AllClose(output, lazy_output);
8543 AllClose(input, lazy_input);
8544 ASSERT_EQ(input.dim(), lazy_input.dim());
8545 for (int64_t dim_idx = 0; dim_idx < input.dim(); ++dim_idx) {
8546 ASSERT_EQ(input.size(dim_idx), lazy_input.size(dim_idx));
8547 }
8548 });
8549 }
8550}
8551
8552TEST_F(LazyOpsTest, TestUnsqueeze) {
8553 torch::Tensor input = torch::rand(
8554 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8555 int rank = input.dim() + 1;
8556 for (int dim = -rank; dim < rank; ++dim) {
8557 torch::Tensor output = torch::unsqueeze(input, dim);
8558 ForEachDevice([&](const torch::Device& device) {
8559 torch::Tensor lazy_input = CopyToDevice(input, device);
8560 torch::Tensor lazy_output = torch::unsqueeze(lazy_input, dim);
8561 AllClose(output, lazy_output);
8562 });
8563 }
8564}
8565
8566TEST_F(LazyOpsTest, TestUnsqueezeInPlace) {
8567 torch::Tensor input = torch::rand(
8568 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8569 int rank = input.dim() + 1;
8570 for (int dim = -rank; dim < rank; ++dim) {
8571 ForEachDevice([&](const torch::Device& device) {
8572 torch::Tensor lazy_input = CopyToDevice(input, device);
8573 torch::Tensor output = input.unsqueeze_(dim);
8574 torch::Tensor lazy_output = lazy_input.unsqueeze_(dim);
8575 AllClose(output, lazy_output);
8576 AllClose(input, lazy_input);
8577 ASSERT_EQ(input.dim(), lazy_input.dim());
8578 for (int64_t dim_idx = 0; dim_idx < input.dim(); ++dim_idx) {
8579 ASSERT_EQ(input.size(dim_idx), lazy_input.size(dim_idx));
8580 }
8581 });
8582 }
8583}
8584
8585TEST_F(LazyOpsTest, TestMaskedFill) {
8586 torch::Tensor input = torch::rand(
8587 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8588 torch::Tensor mask = torch::randint(
8589 0, 2, {2, 3}, torch::TensorOptions(torch::kBool).device(DefaultDevice()));
8590 torch::Scalar value(42);
8591 torch::Tensor result = torch::masked_fill(input, mask, value);
8592 ForEachDevice([&](const torch::Device& device) {
8593 torch::Tensor lazy_input = CopyToDevice(input, device);
8594 torch::Tensor lazy_mask = CopyToDevice(mask, device);
8595 torch::Tensor lazy_result =
8596 torch::masked_fill(lazy_input, lazy_mask, value);
8597 AllClose(result, lazy_result);
8598 });
8599}
8600
8601TEST_F(LazyOpsTest, TestMaskedFillInPlace) {
8602 torch::Scalar value(42);
8603 torch::Tensor mask = torch::randint(
8604 0, 2, {2, 3}, torch::TensorOptions(torch::kBool).device(DefaultDevice()));
8605 ForEachDevice([&](const torch::Device& device) {
8606 torch::Tensor input = torch::rand(
8607 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8608 torch::Tensor lazy_input = CopyToDevice(input, device);
8609 torch::Tensor lazy_mask = CopyToDevice(mask, device);
8610 torch::Tensor result = input.masked_fill_(mask, value);
8611 torch::Tensor lazy_result = lazy_input.masked_fill_(lazy_mask, value);
8612 AllClose(result, lazy_result);
8613 AllClose(input, lazy_input);
8614 });
8615}
8616
8617TEST_F(LazyOpsTest, TestMaskedFillBroadcast) {
8618 torch::Tensor input = torch::rand(
8619 {2, 5, 4, 3},
8620 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8621 torch::Tensor mask = torch::randint(
8622 0, 2, {4, 1}, torch::TensorOptions(torch::kBool).device(DefaultDevice()));
8623 torch::Scalar value(42);
8624 torch::Tensor result = torch::masked_fill(input, mask, value);
8625 ForEachDevice([&](const torch::Device& device) {
8626 torch::Tensor lazy_input = CopyToDevice(input, device);
8627 torch::Tensor lazy_mask = CopyToDevice(mask, device);
8628 torch::Tensor lazy_result =
8629 torch::masked_fill(lazy_input, lazy_mask, value);
8630 AllClose(result, lazy_result);
8631 });
8632}
8633
8634TEST_F(LazyOpsTest, TestFill) {
8635 torch::Scalar value(42);
8636 ForEachDevice([&](const torch::Device& device) {
8637 torch::Tensor input = torch::empty(
8638 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8639 torch::Tensor lazy_input = CopyToDevice(input, device);
8640 torch::Tensor result = torch::fill_(input, value);
8641 torch::Tensor lazy_result = torch::fill_(lazy_input, value);
8642 AllClose(result, lazy_result);
8643 AllClose(input, lazy_input);
8644 });
8645}
8646
8647TEST_F(LazyOpsTest, TestFillWithRank0) {
8648 torch::Tensor value = torch::scalar_tensor(42);
8649 ForEachDevice([&](const torch::Device& device) {
8650 torch::Tensor input = torch::empty(
8651 {2, 3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8652 torch::Tensor lazy_input = CopyToDevice(input, device);
8653 torch::Tensor result = torch::fill_(input, value);
8654 torch::Tensor lazy_value = CopyToDevice(value, device);
8655 torch::Tensor lazy_result = torch::fill_(lazy_input, value);
8656 AllClose(result, lazy_result);
8657 AllClose(input, lazy_input);
8658 });
8659}
8660
8661TEST_F(LazyOpsTest, TestPermute) {
8662 torch::Tensor input = torch::rand(
8663 {2, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8664 std::vector<std::vector<int64_t>> dims_permutations = {
8665 {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}};
8666 int rank = input.dim();
8667 for (std::vector<int64_t> dims_permutation : dims_permutations) {
8668 for (bool negative_dims : {false, true}) {
8669 if (negative_dims) {
8670 std::for_each(
8671 dims_permutation.begin(),
8672 dims_permutation.end(),
8673 [rank](int64_t& dim) { dim -= rank; });
8674 }
8675 torch::Tensor output = input.permute(dims_permutation);
8676 ForEachDevice([&](const torch::Device& device) {
8677 torch::Tensor lazy_input = CopyToDevice(input, device);
8678 torch::Tensor lazy_output = lazy_input.permute(dims_permutation);
8679 AllClose(output, lazy_output);
8680 });
8681 }
8682 }
8683}
8684
8685TEST_F(LazyOpsTest, TestPermuteMod) {
8686 std::vector<std::vector<int64_t>> dims_permutations = {
8687 {0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0}};
8688 std::vector<int64_t> input_sizes = {2, 3, 4};
8689 int rank = input_sizes.size();
8690 for (std::vector<int64_t> dims_permutation : dims_permutations) {
8691 for (bool negative_dims : {false, true}) {
8692 if (negative_dims) {
8693 std::for_each(
8694 dims_permutation.begin(),
8695 dims_permutation.end(),
8696 [rank](int64_t& dim) { dim -= rank; });
8697 }
8698 torch::Tensor input = torch::zeros(
8699 input_sizes,
8700 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8701 torch::Tensor one = torch::tensor(
8702 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8703 torch::Tensor output = input.permute(dims_permutation);
8704 output.add_(one, 1.0);
8705 input.add_(one, 1.0);
8706 ForEachDevice([&](const torch::Device& device) {
8707 torch::Tensor xinput = torch::zeros(
8708 input_sizes,
8709 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8710 torch::Tensor lazy_input = CopyToDevice(xinput, device);
8711 torch::Tensor lazy_one = CopyToDevice(one, device);
8712 torch::Tensor lazy_output = lazy_input.permute(dims_permutation);
8713 lazy_output.add_(lazy_one, 1.0);
8714 lazy_input.add_(lazy_one, 1.0);
8715 AllClose(output, lazy_output);
8716 AllClose(input, lazy_input);
8717 });
8718 }
8719 }
8720}
8721
8722TEST_F(LazyOpsTest, TestFlip) {
8723 torch::Tensor input = torch::rand(
8724 {2, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8725 std::vector<std::vector<int64_t>> dim_powerset = {
8726 {0}, {1}, {2}, {0, 1}, {1, 2}, {2, 0}, {0, 1, 2}};
8727 for (std::vector<int64_t> flip_dims : dim_powerset) {
8728 for (bool negative_dims : {false, true}) {
8729 if (negative_dims) {
8730 std::for_each(
8731 flip_dims.begin(), flip_dims.end(), [](int64_t& dim) { dim -= 3; });
8732 }
8733 torch::Tensor output = torch::flip(input, flip_dims);
8734 ForEachDevice([&](const torch::Device& device) {
8735 torch::Tensor lazy_input = CopyToDevice(input, device);
8736 torch::Tensor lazy_output = torch::flip(lazy_input, flip_dims);
8737 AllClose(output, lazy_output);
8738 });
8739 }
8740 }
8741}
8742
8743TEST_F(LazyOpsTest, TestPixelShuffle) {
8744 torch::Tensor input = torch::rand(
8745 {5, 18, 4, 4},
8746 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8747 int upscale_factor = 3;
8748 ForEachDevice([&](const torch::Device& device) {
8749 torch::Tensor lazy_input = CopyToDevice(input, device);
8750 torch::Tensor output = torch::pixel_shuffle(input, upscale_factor);
8751 torch::Tensor lazy_output =
8752 torch::pixel_shuffle(lazy_input, upscale_factor);
8753 AllClose(output, lazy_output);
8754 });
8755}
8756
8757TEST_F(LazyOpsTest, TestSumToSize) {
8758 torch::Tensor input = torch::rand(
8759 {4, 6, 3, 7},
8760 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8761 std::vector<int64_t> out_size = {4, 1, 1, 7};
8762 ForEachDevice([&](const torch::Device& device) {
8763 torch::Tensor lazy_input = CopyToDevice(input, device);
8764 torch::Tensor output = input.sum_to_size(out_size);
8765 torch::Tensor lazy_output = lazy_input.sum_to_size(out_size);
8766 AllClose(output, lazy_output);
8767 });
8768}
8769
8770TEST_F(LazyOpsTest, TestTransposeDims) {
8771 torch::Tensor input = torch::rand(
8772 {2, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8773 int dim0 = 0;
8774 int dim1 = 2;
8775 torch::Tensor output = torch::transpose(input, dim0, dim1);
8776 ForEachDevice([&](const torch::Device& device) {
8777 torch::Tensor lazy_input = CopyToDevice(input, device);
8778 torch::Tensor lazy_output = torch::transpose(lazy_input, dim0, dim1);
8779 AllClose(output, lazy_output);
8780 });
8781}
8782
8783TEST_F(LazyOpsTest, TestTransposeDimsMod) {
8784 std::vector<int64_t> input_sizes = {2, 3, 4};
8785 int dim0 = 0;
8786 int dim1 = 2;
8787 torch::Tensor input = torch::zeros(
8788 input_sizes, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8789 torch::Tensor one = torch::tensor(
8790 1.0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8791 torch::Tensor output = torch::transpose(input, dim0, dim1);
8792 output.add_(one, 1.0);
8793 input.add_(one, 1.0);
8794 ForEachDevice([&](const torch::Device& device) {
8795 torch::Tensor xinput = torch::zeros(
8796 input_sizes,
8797 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8798 torch::Tensor lazy_input = CopyToDevice(xinput, device);
8799 torch::Tensor lazy_one = CopyToDevice(one, device);
8800 torch::Tensor lazy_output = torch::transpose(lazy_input, dim0, dim1);
8801 lazy_output.add_(lazy_one, 1.0);
8802 lazy_input.add_(lazy_one, 1.0);
8803 AllClose(output, lazy_output);
8804 AllClose(input, lazy_input);
8805 });
8806}
8807
8808TEST_F(LazyOpsTest, TestTransposeDimsInPlace) {
8809 torch::Tensor input = torch::rand(
8810 {2, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8811 int dim0 = 0;
8812 int dim1 = 2;
8813 ForEachDevice([&](const torch::Device& device) {
8814 torch::Tensor lazy_input = CopyToDevice(input, device);
8815 torch::Tensor output = input.transpose_(dim0, dim1);
8816 torch::Tensor lazy_output = lazy_input.transpose_(dim0, dim1);
8817 AllClose(output, lazy_output);
8818 AllClose(input, lazy_input);
8819 });
8820}
8821
8822TEST_F(LazyOpsTest, TestSplit) {
8823 torch::Tensor input = torch::rand(
8824 {7, 8, 9}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8825 int rank = input.dim();
8826 for (int split_size : {2, 3}) {
8827 for (int dim = -rank; dim < rank; ++dim) {
8828 std::vector<torch::Tensor> outputs = torch::split(input, split_size, dim);
8829 ForEachDevice([&](const torch::Device& device) {
8830 torch::Tensor lazy_input = CopyToDevice(input, device);
8831 std::vector<torch::Tensor> lazy_outputs =
8832 torch::split(lazy_input, split_size, dim);
8833 ASSERT_EQ(outputs.size(), lazy_outputs.size());
8834 for (size_t i = 0; i < outputs.size(); ++i) {
8835 AllClose(outputs[i], lazy_outputs[i]);
8836 }
8837 });
8838 }
8839 }
8840}
8841
8842TEST_F(LazyOpsTest, TestSplitEmpty) {
8843 torch::Tensor input = torch::rand(
8844 {0}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8845 int split_size = 0;
8846 int dim = 0;
8847 std::vector<torch::Tensor> outputs = torch::split(input, split_size, dim);
8848 ForEachDevice([&](const torch::Device& device) {
8849 torch::Tensor lazy_input = CopyToDevice(input, device);
8850 std::vector<torch::Tensor> lazy_outputs =
8851 torch::split(lazy_input, split_size, dim);
8852 ASSERT_EQ(outputs.size(), lazy_outputs.size());
8853 for (size_t i = 0; i < outputs.size(); ++i) {
8854 AllClose(outputs[i], lazy_outputs[i]);
8855 }
8856 });
8857}
8858
8859TEST_F(LazyOpsTest, TestSplitWithSizes) {
8860 torch::Tensor input = torch::rand(
8861 {15, 15, 15},
8862 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8863 int rank = input.dim();
8864 for (int dim = -rank; dim < rank; ++dim) {
8865 std::vector<torch::Tensor> outputs =
8866 torch::split_with_sizes(input, {4, 5, 6}, dim);
8867 ForEachDevice([&](const torch::Device& device) {
8868 torch::Tensor lazy_input = CopyToDevice(input, device);
8869 std::vector<torch::Tensor> lazy_outputs =
8870 torch::split_with_sizes(lazy_input, {4, 5, 6}, dim);
8871 ASSERT_EQ(outputs.size(), lazy_outputs.size());
8872 for (size_t i = 0; i < outputs.size(); ++i) {
8873 AllClose(outputs[i], lazy_outputs[i]);
8874 }
8875 });
8876 }
8877}
8878
8879TEST_F(LazyOpsTest, TestCrossImplicitDim) {
8880 std::vector<std::vector<int64_t>> dim_sizes = {
8881 {4, 5, 3}, {4, 3, 5}, {3, 4, 5}};
8882 for (auto dim_size : dim_sizes) {
8883 torch::Tensor input = torch::rand(
8884 dim_size, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8885 torch::Tensor other = torch::rand(
8886 dim_size, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8887 torch::Tensor result = torch::cross(input, other);
8888 ForEachDevice([&](const torch::Device& device) {
8889 torch::Tensor lazy_input = CopyToDevice(input, device);
8890 torch::Tensor lazy_other = CopyToDevice(other, device);
8891 torch::Tensor lazy_result = torch::cross(lazy_input, lazy_other);
8892 AllClose(result, lazy_result);
8893 });
8894 }
8895}
8896
8897TEST_F(LazyOpsTest, TestCrossExplicitDim) {
8898 std::vector<int64_t> dim_size = {3, 3};
8899 torch::Tensor input = torch::rand(
8900 dim_size, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8901 torch::Tensor other = torch::rand(
8902 dim_size, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8903 int rank = dim_size.size();
8904 for (int dim = -rank; dim < rank; ++dim) {
8905 torch::Tensor result = torch::cross(input, other, dim);
8906 ForEachDevice([&](const torch::Device& device) {
8907 torch::Tensor lazy_input = CopyToDevice(input, device);
8908 torch::Tensor lazy_other = CopyToDevice(other, device);
8909 torch::Tensor lazy_result = torch::cross(lazy_input, lazy_other, dim);
8910 AllClose(result, lazy_result);
8911 });
8912 }
8913}
8914
8915TEST_F(LazyOpsTest, TestCrossZeroDim) {
8916 torch::Tensor input = torch::rand(
8917 {0, 1, 3, 0},
8918 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8919 torch::Tensor result = torch::cross(input, input);
8920 ForEachDevice([&](const torch::Device& device) {
8921 torch::Tensor lazy_input = CopyToDevice(input, device);
8922 torch::Tensor lazy_result = torch::cross(lazy_input, lazy_input);
8923 AllClose(result, lazy_result);
8924 });
8925}
8926
8927TEST_F(LazyOpsTest, TestTriu) {
8928 int size = 5;
8929 torch::Tensor input = torch::rand(
8930 {size, size},
8931 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8932 // Test all diagonals and out of bounds (must be no-op).
8933 for (int diagonal = -size; diagonal <= size; ++diagonal) {
8934 torch::Tensor output = torch::triu(input, diagonal);
8935 ForEachDevice([&](const torch::Device& device) {
8936 torch::Tensor lazy_input = CopyToDevice(input, device);
8937 torch::Tensor lazy_output = torch::triu(lazy_input, diagonal);
8938 AllClose(output, lazy_output);
8939 });
8940 }
8941}
8942
8943TEST_F(LazyOpsTest, TestTriuNonSquare) {
8944 int size = 5;
8945 torch::Tensor input = torch::rand(
8946 {size, size + 1},
8947 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8948 // Test all diagonals and out of bounds (must be no-op).
8949 for (int diagonal = -size; diagonal <= size; ++diagonal) {
8950 torch::Tensor output = torch::triu(input, diagonal);
8951 ForEachDevice([&](const torch::Device& device) {
8952 torch::Tensor lazy_input = CopyToDevice(input, device);
8953 torch::Tensor lazy_output = torch::triu(lazy_input, diagonal);
8954 AllClose(output, lazy_output);
8955 });
8956 }
8957}
8958
8959TEST_F(LazyOpsTest, TestTriuBatch) {
8960 int size = 5;
8961 int batch_size = 3;
8962 torch::Tensor input = torch::rand(
8963 {batch_size, size, size},
8964 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8965 // Test all diagonals and out of bounds (must be no-op).
8966 for (int diagonal = -size; diagonal <= size; ++diagonal) {
8967 torch::Tensor output = torch::triu(input, diagonal);
8968 ForEachDevice([&](const torch::Device& device) {
8969 torch::Tensor lazy_input = CopyToDevice(input, device);
8970 torch::Tensor lazy_output = torch::triu(lazy_input, diagonal);
8971 AllClose(output, lazy_output);
8972 });
8973 }
8974}
8975
8976TEST_F(LazyOpsTest, TestTril) {
8977 int size = 5;
8978 torch::Tensor input = torch::rand(
8979 {size, size},
8980 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8981 // Test all diagonals and out of bounds (must be no-op).
8982 for (int diagonal = -size; diagonal <= size; ++diagonal) {
8983 torch::Tensor output = torch::tril(input, diagonal);
8984 ForEachDevice([&](const torch::Device& device) {
8985 torch::Tensor lazy_input = CopyToDevice(input, device);
8986 torch::Tensor lazy_output = torch::tril(lazy_input, diagonal);
8987 AllClose(output, lazy_output);
8988 });
8989 }
8990}
8991
8992TEST_F(LazyOpsTest, TestTrilNonSquare) {
8993 int size = 5;
8994 torch::Tensor input = torch::rand(
8995 {size, size + 1},
8996 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
8997 // Test all diagonals and out of bounds (must be no-op).
8998 for (int diagonal = -size; diagonal <= size; ++diagonal) {
8999 torch::Tensor output = torch::tril(input, diagonal);
9000 ForEachDevice([&](const torch::Device& device) {
9001 torch::Tensor lazy_input = CopyToDevice(input, device);
9002 torch::Tensor lazy_output = torch::tril(lazy_input, diagonal);
9003 AllClose(output, lazy_output);
9004 });
9005 }
9006}
9007
9008TEST_F(LazyOpsTest, TestTrilBatch) {
9009 int size = 5;
9010 int batch_size = 3;
9011 torch::Tensor input = torch::rand(
9012 {batch_size, size, size},
9013 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9014 // Test all diagonals and out of bounds (must be no-op).
9015 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9016 torch::Tensor output = torch::tril(input, diagonal);
9017 ForEachDevice([&](const torch::Device& device) {
9018 torch::Tensor lazy_input = CopyToDevice(input, device);
9019 torch::Tensor lazy_output = torch::tril(lazy_input, diagonal);
9020 AllClose(output, lazy_output);
9021 });
9022 }
9023}
9024
9025TEST_F(LazyOpsTest, TestTriuInPlace) {
9026 int size = 5;
9027 // Test all diagonals and out of bounds (must be no-op).
9028 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9029 ForEachDevice([&](const torch::Device& device) {
9030 torch::Tensor input = torch::rand(
9031 {size, size},
9032 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9033 torch::Tensor lazy_input = CopyToDevice(input, device);
9034 torch::Tensor output = input.triu_(diagonal);
9035 torch::Tensor lazy_output = lazy_input.triu_(diagonal);
9036 AllClose(output, lazy_output);
9037 AllClose(input, lazy_input);
9038 });
9039 }
9040}
9041
9042TEST_F(LazyOpsTest, TestTrilInPlace) {
9043 int size = 5;
9044 // Test all diagonals and out of bounds (must be no-op).
9045 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9046 ForEachDevice([&](const torch::Device& device) {
9047 torch::Tensor input = torch::rand(
9048 {size, size},
9049 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9050 torch::Tensor lazy_input = CopyToDevice(input, device);
9051 torch::Tensor output = input.tril_(diagonal);
9052 torch::Tensor lazy_output = lazy_input.tril_(diagonal);
9053 AllClose(output, lazy_output);
9054 AllClose(input, lazy_input);
9055 });
9056 }
9057}
9058
9059TEST_F(LazyOpsTest, TestTrace) {
9060 int n = 5;
9061 torch::Tensor input = torch::rand(
9062 {n, n}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9063 torch::Tensor output = torch::trace(input);
9064 ForEachDevice([&](const torch::Device& device) {
9065 torch::Tensor lazy_input = CopyToDevice(input, device);
9066 torch::Tensor lazy_output = torch::trace(lazy_input);
9067 AllClose(output, lazy_output);
9068 });
9069}
9070
9071TEST_F(LazyOpsTest, TestTraceWide) {
9072 int lines = 3;
9073 int cols = 5;
9074 torch::Tensor input = torch::rand(
9075 {lines, cols},
9076 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9077 torch::Tensor output = torch::trace(input);
9078 ForEachDevice([&](const torch::Device& device) {
9079 torch::Tensor lazy_input = CopyToDevice(input, device);
9080 torch::Tensor lazy_output = torch::trace(lazy_input);
9081 AllClose(output, lazy_output);
9082 });
9083}
9084
9085TEST_F(LazyOpsTest, TestTraceNarrow) {
9086 int lines = 5;
9087 int cols = 3;
9088 torch::Tensor input = torch::rand(
9089 {lines, cols},
9090 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9091 torch::Tensor output = torch::trace(input);
9092 ForEachDevice([&](const torch::Device& device) {
9093 torch::Tensor lazy_input = CopyToDevice(input, device);
9094 torch::Tensor lazy_output = torch::trace(lazy_input);
9095 AllClose(output, lazy_output);
9096 });
9097}
9098
9099TEST_F(LazyOpsTest, TestDiagRank1) {
9100 int size = 7;
9101 torch::Tensor input = torch::rand(
9102 {size}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9103 // Test all diagonals and out of bounds (must be no-op).
9104 for (int diagonal = -2 * size; diagonal <= 2 * size; ++diagonal) {
9105 torch::Tensor output = torch::diag(input, diagonal);
9106 ForEachDevice([&](const torch::Device& device) {
9107 torch::Tensor lazy_input = CopyToDevice(input, device);
9108 torch::Tensor lazy_output = torch::diag(lazy_input, diagonal);
9109 AllClose(output, lazy_output);
9110 });
9111 }
9112}
9113
9114TEST_F(LazyOpsTest, TestDiagRank2) {
9115 int size = 7;
9116 torch::Tensor input = torch::rand(
9117 {size, size},
9118 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9119 // Test all diagonals and out of bounds (must be no-op).
9120 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9121 torch::Tensor output = torch::diag(input, diagonal);
9122 ForEachDevice([&](const torch::Device& device) {
9123 torch::Tensor lazy_input = CopyToDevice(input, device);
9124 torch::Tensor lazy_output = torch::diag(lazy_input, diagonal);
9125 AllClose(output, lazy_output);
9126 });
9127 }
9128}
9129
9130TEST_F(LazyOpsTest, TestDiagFlat) {
9131 torch::Tensor input = torch::rand(
9132 {4, 3, 6, 7},
9133 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9134 for (int diagonal = -10; diagonal < 10; ++diagonal) {
9135 torch::Tensor output = torch::diagflat(input, diagonal);
9136 ForEachDevice([&](const torch::Device& device) {
9137 torch::Tensor lazy_input = CopyToDevice(input, device);
9138 torch::Tensor lazy_output = torch::diagflat(lazy_input, diagonal);
9139 AllClose(output, lazy_output);
9140 });
9141 }
9142}
9143
9144TEST_F(LazyOpsTest, TestDiagonal) {
9145 int size = 5;
9146 torch::Tensor input = torch::rand(
9147 {size, size},
9148 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9149 // Test all diagonals and out of bounds (must be no-op).
9150 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9151 torch::Tensor output = torch::diagonal(input, diagonal);
9152 ForEachDevice([&](const torch::Device& device) {
9153 torch::Tensor lazy_input = CopyToDevice(input, device);
9154 torch::Tensor lazy_output = torch::diagonal(lazy_input, diagonal);
9155 AllClose(output, lazy_output);
9156 });
9157 }
9158}
9159
9160TEST_F(LazyOpsTest, TestDiagonalUpdate) {
9161 int size = 5;
9162 // Test all diagonals and out of bounds (must be no-op).
9163 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9164 auto input = torch::rand(
9165 {size, size},
9166 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9167 auto input_clone = input.clone();
9168 auto output = torch::diagonal(input, diagonal);
9169 output.add_(1);
9170
9171 ForEachDevice([&](const torch::Device& device) {
9172 torch::Tensor lazy_input = CopyToDevice(input_clone, device);
9173 torch::Tensor lazy_output = torch::diagonal(lazy_input, diagonal);
9174 lazy_output.add_(1);
9175
9176 AllClose(output, lazy_output);
9177 AllClose(input, lazy_input);
9178 });
9179 }
9180}
9181
9182TEST_F(LazyOpsTest, TestDiagonalNonSquare) {
9183 int size = 5;
9184 torch::Tensor input = torch::rand(
9185 {size, size + 1},
9186 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9187 // Test all diagonals and out of bounds (must be no-op).
9188 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9189 torch::Tensor output = torch::diagonal(input, diagonal);
9190 ForEachDevice([&](const torch::Device& device) {
9191 torch::Tensor lazy_input = CopyToDevice(input, device);
9192 torch::Tensor lazy_output = torch::diagonal(lazy_input, diagonal);
9193 AllClose(output, lazy_output);
9194 });
9195 }
9196}
9197
9198TEST_F(LazyOpsTest, TestDiagonalBatch) {
9199 int size = 5;
9200 int batch_size = 3;
9201 int dim1 = 1;
9202 int dim2 = 2;
9203 torch::Tensor input = torch::rand(
9204 {batch_size, size, size},
9205 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9206 // Test all diagonals and out of bounds (must be no-op).
9207 for (int diagonal = -size; diagonal <= size; ++diagonal) {
9208 torch::Tensor output =
9209 torch::diagonal(input, diagonal, /*dim1=*/dim1, /*dim1=*/dim2);
9210 ForEachDevice([&](const torch::Device& device) {
9211 torch::Tensor lazy_input = CopyToDevice(input, device);
9212 torch::Tensor lazy_output =
9213 torch::diagonal(lazy_input, diagonal, /*dim1=*/dim1, /*dim1=*/dim2);
9214 AllClose(output, lazy_output);
9215 });
9216 }
9217}
9218
9219TEST_F(LazyOpsTest, TestFlatten) {
9220 torch::Tensor input = torch::rand({4, 7, 5, 3});
9221 int rank = input.dim();
9222 for (int pos_start_dim = 0; pos_start_dim < rank; ++pos_start_dim) {
9223 for (int pos_end_dim = pos_start_dim; pos_end_dim < rank; ++pos_end_dim) {
9224 for (bool negative_start_dim : {false, true}) {
9225 for (bool negative_end_dim : {false, true}) {
9226 int start_dim =
9227 negative_start_dim ? pos_start_dim - rank : pos_start_dim;
9228 int end_dim = negative_end_dim ? pos_end_dim - rank : pos_end_dim;
9229 torch::Tensor output = torch::flatten(input, start_dim, end_dim);
9230 ForEachDevice([&](const torch::Device& device) {
9231 torch::Tensor lazy_input = CopyToDevice(input, device);
9232 torch::Tensor lazy_output =
9233 torch::flatten(lazy_input, start_dim, end_dim);
9234 AllClose(output, lazy_output);
9235 });
9236 }
9237 }
9238 }
9239 }
9240}
9241
9242TEST_F(LazyOpsTest, TestLogicalAnd) {
9243 for (torch::ScalarType scalar_type1 :
9244 {torch::kFloat,
9245 torch::kByte,
9246 torch::kChar,
9247 torch::kShort,
9248 torch::kInt,
9249 torch::kLong}) {
9250 torch::Tensor lhs = isFloatingType(scalar_type1)
9251 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type1))
9252 : torch::randint(0, 100, {3, 4}, torch::TensorOptions(scalar_type1));
9253 for (torch::ScalarType scalar_type2 :
9254 {torch::kFloat,
9255 torch::kByte,
9256 torch::kChar,
9257 torch::kShort,
9258 torch::kInt,
9259 torch::kLong}) {
9260 torch::Tensor rhs = isFloatingType(scalar_type2)
9261 ? torch::rand({3, 4}, torch::TensorOptions(scalar_type2))
9262 : torch::randint(1, 100, {3, 4}, torch::TensorOptions(scalar_type2));
9263 torch::Tensor result = torch::logical_and(lhs, rhs);
9264 ForEachDevice([&](const torch::Device& device) {
9265 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9266 torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
9267 torch::Tensor lazy_result = torch::logical_and(lazy_lhs, lazy_rhs);
9268 AllEqual(result, lazy_result);
9269 });
9270 }
9271 }
9272
9273 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
9274 ExpectCounterChanged("xla::logical_and_out", GetIgnoredCounters());
9275}
9276
9277TEST_F(LazyOpsTest, TestBitwiseAnd) {
9278 torch::Tensor lhs = torch::randint(
9279 0,
9280 std::numeric_limits<int32_t>::max(),
9281 {4, 2},
9282 torch::TensorOptions(torch::kInt));
9283 torch::Tensor rhs = torch::randint(
9284 0,
9285 std::numeric_limits<int32_t>::max(),
9286 {4, 2},
9287 torch::TensorOptions(torch::kInt));
9288 torch::Tensor result = lhs.__and__(rhs);
9289 ForEachDevice([&](const torch::Device& device) {
9290 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9291 torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
9292 torch::Tensor lazy_result = lazy_lhs.__and__(lazy_rhs);
9293 AllEqual(result, lazy_result);
9294 });
9295}
9296
9297TEST_F(LazyOpsTest, TestBitwiseAndInPlace) {
9298 torch::Tensor lhs = torch::randint(
9299 0,
9300 std::numeric_limits<int32_t>::max(),
9301 {4, 2},
9302 torch::TensorOptions(torch::kInt));
9303 torch::Tensor rhs = torch::randint(
9304 0,
9305 std::numeric_limits<int32_t>::max(),
9306 {4, 2},
9307 torch::TensorOptions(torch::kInt));
9308 ForEachDevice([&](const torch::Device& device) {
9309 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9310 torch::Tensor result = lhs.__iand__(rhs);
9311 torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
9312 torch::Tensor lazy_result = lazy_lhs.__iand__(lazy_rhs);
9313 AllEqual(result, lazy_result);
9314 AllEqual(lhs, lazy_lhs);
9315 });
9316}
9317
9318TEST_F(LazyOpsTest, TestBitwiseAndScalar) {
9319 torch::Tensor lhs = torch::randint(
9320 0,
9321 std::numeric_limits<int32_t>::max(),
9322 {4, 2},
9323 torch::TensorOptions(torch::kInt));
9324 torch::Scalar rhs(123456789);
9325 torch::Tensor result = lhs.__and__(rhs);
9326 ForEachDevice([&](const torch::Device& device) {
9327 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9328 torch::Tensor lazy_result = lazy_lhs.__and__(rhs);
9329 AllEqual(result, lazy_result);
9330 });
9331}
9332
9333TEST_F(LazyOpsTest, TestBitwiseAndScalarInPlace) {
9334 torch::Tensor lhs = torch::randint(
9335 0,
9336 std::numeric_limits<int32_t>::max(),
9337 {4, 2},
9338 torch::TensorOptions(torch::kInt));
9339 torch::Scalar rhs(123456789);
9340 ForEachDevice([&](const torch::Device& device) {
9341 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9342 torch::Tensor result = lhs.__iand__(rhs);
9343 torch::Tensor lazy_result = lazy_lhs.__iand__(rhs);
9344 AllEqual(result, lazy_result);
9345 AllEqual(lhs, lazy_lhs);
9346 });
9347}
9348
9349TEST_F(LazyOpsTest, TestBitwiseAndPromotion) {
9350 torch::Tensor input = torch::rand(
9351 {4, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9352 torch::Tensor view = input.reshape(-1);
9353 torch::Tensor result = torch::__and__(view.gt(0), view.ne(0));
9354 ForEachDevice([&](const torch::Device& device) {
9355 torch::Tensor lazy_input = CopyToDevice(input, device);
9356 torch::Tensor lazy_view = lazy_input.reshape(-1);
9357 torch::Tensor lazy_result =
9358 torch::__and__(lazy_view.gt(0), lazy_view.ne(0));
9359 AllEqual(result, lazy_result);
9360 });
9361}
9362
9363TEST_F(LazyOpsTest, TestBitwiseOr) {
9364 torch::Tensor lhs = torch::randint(
9365 0,
9366 std::numeric_limits<int32_t>::max(),
9367 {4, 2},
9368 torch::TensorOptions(torch::kInt));
9369 torch::Tensor rhs = torch::randint(
9370 0,
9371 std::numeric_limits<int32_t>::max(),
9372 {4, 2},
9373 torch::TensorOptions(torch::kInt));
9374 torch::Tensor result = lhs.__or__(rhs);
9375 ForEachDevice([&](const torch::Device& device) {
9376 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9377 torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
9378 torch::Tensor lazy_result = lazy_lhs.__or__(lazy_rhs);
9379 AllEqual(result, lazy_result);
9380 });
9381}
9382
9383TEST_F(LazyOpsTest, TestBitwiseOrInPlace) {
9384 torch::Tensor lhs = torch::randint(
9385 0,
9386 std::numeric_limits<int32_t>::max(),
9387 {4, 2},
9388 torch::TensorOptions(torch::kInt));
9389 torch::Tensor rhs = torch::randint(
9390 0,
9391 std::numeric_limits<int32_t>::max(),
9392 {4, 2},
9393 torch::TensorOptions(torch::kInt));
9394 ForEachDevice([&](const torch::Device& device) {
9395 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9396 torch::Tensor result = lhs.__ior__(rhs);
9397 torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
9398 torch::Tensor lazy_result = lazy_lhs.__ior__(lazy_rhs);
9399 AllEqual(result, lazy_result);
9400 AllEqual(lhs, lazy_lhs);
9401 });
9402}
9403
9404TEST_F(LazyOpsTest, TestBitwiseOrScalar) {
9405 torch::Tensor lhs = torch::randint(
9406 0,
9407 std::numeric_limits<int32_t>::max(),
9408 {4, 2},
9409 torch::TensorOptions(torch::kInt));
9410 torch::Scalar rhs(123456789);
9411 torch::Tensor result = lhs.__or__(rhs);
9412 ForEachDevice([&](const torch::Device& device) {
9413 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9414 torch::Tensor lazy_result = lazy_lhs.__or__(rhs);
9415 AllEqual(result, lazy_result);
9416 });
9417}
9418
9419TEST_F(LazyOpsTest, TestBitwiseOrScalarInPlace) {
9420 torch::Tensor lhs = torch::randint(
9421 0,
9422 std::numeric_limits<int32_t>::max(),
9423 {4, 2},
9424 torch::TensorOptions(torch::kInt));
9425 torch::Scalar rhs(123456789);
9426 ForEachDevice([&](const torch::Device& device) {
9427 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9428 torch::Tensor result = lhs.__ior__(rhs);
9429 torch::Tensor lazy_result = lazy_lhs.__ior__(rhs);
9430 AllEqual(result, lazy_result);
9431 AllEqual(lhs, lazy_lhs);
9432 });
9433}
9434
9435TEST_F(LazyOpsTest, TestBitwiseXor) {
9436 torch::Tensor lhs = torch::randint(
9437 0,
9438 std::numeric_limits<int32_t>::max(),
9439 {4, 2},
9440 torch::TensorOptions(torch::kInt));
9441 torch::Tensor rhs = torch::randint(
9442 0,
9443 std::numeric_limits<int32_t>::max(),
9444 {4, 2},
9445 torch::TensorOptions(torch::kInt));
9446 torch::Tensor result = lhs.__xor__(rhs);
9447 ForEachDevice([&](const torch::Device& device) {
9448 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9449 torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
9450 torch::Tensor lazy_result = lazy_lhs.__xor__(lazy_rhs);
9451 AllEqual(result, lazy_result);
9452 });
9453}
9454
9455TEST_F(LazyOpsTest, TestBitwiseXorInPlace) {
9456 torch::Tensor lhs = torch::randint(
9457 0,
9458 std::numeric_limits<int32_t>::max(),
9459 {4, 2},
9460 torch::TensorOptions(torch::kInt));
9461 torch::Tensor rhs = torch::randint(
9462 0,
9463 std::numeric_limits<int32_t>::max(),
9464 {4, 2},
9465 torch::TensorOptions(torch::kInt));
9466 ForEachDevice([&](const torch::Device& device) {
9467 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9468 torch::Tensor result = lhs.__ixor__(rhs);
9469 torch::Tensor lazy_rhs = CopyToDevice(rhs, device);
9470 torch::Tensor lazy_result = lazy_lhs.__ixor__(lazy_rhs);
9471 AllEqual(result, lazy_result);
9472 AllEqual(lhs, lazy_lhs);
9473 });
9474}
9475
9476TEST_F(LazyOpsTest, TestBitwiseXorScalar) {
9477 torch::Tensor lhs = torch::randint(
9478 0,
9479 std::numeric_limits<int32_t>::max(),
9480 {4, 2},
9481 torch::TensorOptions(torch::kInt));
9482 torch::Scalar rhs(123456789);
9483 torch::Tensor result = lhs.__xor__(rhs);
9484 ForEachDevice([&](const torch::Device& device) {
9485 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9486 torch::Tensor lazy_result = lazy_lhs.__xor__(rhs);
9487 AllEqual(result, lazy_result);
9488 });
9489}
9490
9491TEST_F(LazyOpsTest, TestBitwiseXorScalarInPlace) {
9492 torch::Tensor lhs = torch::randint(
9493 0,
9494 std::numeric_limits<int32_t>::max(),
9495 {4, 2},
9496 torch::TensorOptions(torch::kInt));
9497 torch::Scalar rhs(123456789);
9498 ForEachDevice([&](const torch::Device& device) {
9499 torch::Tensor lazy_lhs = CopyToDevice(lhs, device);
9500 torch::Tensor result = lhs.__ixor__(rhs);
9501 torch::Tensor lazy_result = lazy_lhs.__ixor__(rhs);
9502 AllEqual(result, lazy_result);
9503 AllEqual(lhs, lazy_lhs);
9504 });
9505}
9506
9507TEST_F(LazyOpsTest, TestLshift) {
9508 torch::Tensor input = torch::ones(
9509 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9510 torch::Tensor shift_amount = torch::randint(
9511 16,
9512 input.sizes(),
9513 torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9514 torch::Tensor result = torch::__lshift__(input, shift_amount);
9515 ForEachDevice([&](const torch::Device& device) {
9516 torch::Tensor lazy_input = CopyToDevice(input, device);
9517 torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
9518 torch::Tensor lazy_result =
9519 torch::__lshift__(lazy_input, lazy_shift_amount);
9520 AllClose(result, lazy_result);
9521 });
9522}
9523
9524TEST_F(LazyOpsTest, TestLshiftInPlace) {
9525 torch::Tensor input = torch::ones(
9526 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9527 ForEachDevice([&](const torch::Device& device) {
9528 torch::Tensor lazy_input = CopyToDevice(input, device);
9529 torch::Tensor shift_amount = torch::randint(
9530 16,
9531 input.sizes(),
9532 torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9533 torch::Tensor result = input.__ilshift__(shift_amount);
9534 torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
9535 torch::Tensor lazy_result = lazy_input.__ilshift__(lazy_shift_amount);
9536 AllClose(result, lazy_result);
9537 AllClose(input, lazy_input);
9538 });
9539}
9540
9541TEST_F(LazyOpsTest, TestLshiftScalar) {
9542 torch::Tensor input = torch::ones(
9543 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9544 torch::Scalar shift_amount = 3;
9545 torch::Tensor result = torch::__lshift__(input, shift_amount);
9546 ForEachDevice([&](const torch::Device& device) {
9547 torch::Tensor lazy_input = CopyToDevice(input, device);
9548 torch::Tensor lazy_result = torch::__lshift__(lazy_input, shift_amount);
9549 AllClose(result, lazy_result);
9550 });
9551}
9552
9553TEST_F(LazyOpsTest, TestLshiftScalarInPlace) {
9554 torch::Tensor input = torch::ones(
9555 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9556 torch::Scalar shift_amount = 3;
9557 ForEachDevice([&](const torch::Device& device) {
9558 torch::Tensor lazy_input = CopyToDevice(input, device);
9559 torch::Tensor result = input.__ilshift__(shift_amount);
9560 torch::Tensor lazy_result = lazy_input.__ilshift__(shift_amount);
9561 AllClose(result, lazy_result);
9562 AllClose(input, lazy_input);
9563 });
9564}
9565
9566TEST_F(LazyOpsTest, TestRshift) {
9567 torch::Tensor input = torch::ones(
9568 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9569 torch::Tensor shift_amount = torch::randint(
9570 16,
9571 input.sizes(),
9572 torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9573 torch::Tensor result = torch::__rshift__(input, shift_amount);
9574 ForEachDevice([&](const torch::Device& device) {
9575 torch::Tensor lazy_input = CopyToDevice(input, device);
9576 torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
9577 torch::Tensor lazy_result =
9578 torch::__rshift__(lazy_input, lazy_shift_amount);
9579 AllClose(result, lazy_result);
9580 });
9581}
9582
9583TEST_F(LazyOpsTest, TestRshiftInPlace) {
9584 torch::Tensor input = torch::ones(
9585 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9586 ForEachDevice([&](const torch::Device& device) {
9587 torch::Tensor lazy_input = CopyToDevice(input, device);
9588 torch::Tensor shift_amount = torch::randint(
9589 16,
9590 input.sizes(),
9591 torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9592 torch::Tensor result = input.__irshift__(shift_amount);
9593 torch::Tensor lazy_shift_amount = CopyToDevice(shift_amount, device);
9594 torch::Tensor lazy_result = lazy_input.__irshift__(lazy_shift_amount);
9595 AllClose(result, lazy_result);
9596 AllClose(input, lazy_input);
9597 });
9598}
9599
9600TEST_F(LazyOpsTest, TestRshiftScalar) {
9601 torch::Tensor input = torch::ones(
9602 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9603 torch::Scalar shift_amount = 3;
9604 torch::Tensor result = torch::__rshift__(input, shift_amount);
9605 ForEachDevice([&](const torch::Device& device) {
9606 torch::Tensor lazy_input = CopyToDevice(input, device);
9607 torch::Tensor lazy_result = torch::__rshift__(lazy_input, shift_amount);
9608 AllClose(result, lazy_result);
9609 });
9610}
9611
9612TEST_F(LazyOpsTest, TestRshiftScalarInPlace) {
9613 torch::Tensor input = torch::ones(
9614 {4, 2}, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
9615 torch::Scalar shift_amount = 3;
9616 ForEachDevice([&](const torch::Device& device) {
9617 torch::Tensor lazy_input = CopyToDevice(input, device);
9618 torch::Tensor result = input.__irshift__(shift_amount);
9619 torch::Tensor lazy_result = lazy_input.__irshift__(shift_amount);
9620 AllClose(result, lazy_result);
9621 AllClose(input, lazy_input);
9622 });
9623}
9624
9625TEST_F(LazyOpsTest, TestMeshgrid) {
9626 torch::Tensor a = torch::rand(
9627 {3}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9628 torch::Tensor b = torch::rand(
9629 {2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9630 torch::Tensor c = torch::rand(
9631 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9632 auto d = torch::meshgrid({a, b, c});
9633 ForEachDevice([&](const torch::Device& device) {
9634 torch::Tensor lazy_a = CopyToDevice(a, device);
9635 torch::Tensor lazy_b = CopyToDevice(b, device);
9636 torch::Tensor lazy_c = CopyToDevice(c, device);
9637 auto lazy_d = torch::meshgrid({lazy_a, lazy_b, lazy_c});
9638 EXPECT_EQ(d.size(), lazy_d.size());
9639 for (size_t i = 0; i < d.size(); ++i) {
9640 AllClose(d[i], lazy_d[i]);
9641 }
9642 });
9643}
9644
9645TEST_F(LazyOpsTest, TestConstantPad) {
9646 torch::Tensor input = torch::rand(
9647 {4, 2, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9648 std::vector<int64_t> pad{1, 2, 3, 4, 5, 6};
9649 float pad_value = 5;
9650 torch::Tensor output = torch::constant_pad_nd(input, pad, pad_value);
9651 ForEachDevice([&](const torch::Device& device) {
9652 torch::Tensor lazy_input = CopyToDevice(input, device);
9653 torch::Tensor lazy_output =
9654 torch::constant_pad_nd(lazy_input, pad, pad_value);
9655 AllClose(output, lazy_output);
9656 });
9657}
9658
9659TEST_F(LazyOpsTest, TestConstantPadIncomplete) {
9660 torch::Tensor input = torch::rand(
9661 {4, 2, 5}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9662 std::vector<int64_t> pad{1, 2};
9663 float pad_value = 5;
9664 torch::Tensor output = torch::constant_pad_nd(input, pad, pad_value);
9665 ForEachDevice([&](const torch::Device& device) {
9666 torch::Tensor lazy_input = CopyToDevice(input, device);
9667 torch::Tensor lazy_output =
9668 torch::constant_pad_nd(lazy_input, pad, pad_value);
9669 AllClose(output, lazy_output);
9670 });
9671}
9672
9673TEST_F(LazyOpsTest, TestReflectionPad2dRank3) {
9674 torch::Tensor input = torch::rand(
9675 {2, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9676 std::vector<int64_t> pad{2, 2, 2, 2};
9677 torch::Tensor output = torch::reflection_pad2d(input, pad);
9678 ForEachDevice([&](const torch::Device& device) {
9679 torch::Tensor lazy_input = CopyToDevice(input, device);
9680 torch::Tensor lazy_output = torch::reflection_pad2d(lazy_input, pad);
9681 AllClose(output, lazy_output);
9682 });
9683}
9684
9685TEST_F(LazyOpsTest, TestReflectionPad2dRank4) {
9686 torch::Tensor input = torch::rand(
9687 {2, 2, 3, 4},
9688 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9689 std::vector<int64_t> pad{2, 2, 2, 2};
9690 torch::Tensor output = torch::reflection_pad2d(input, pad);
9691 ForEachDevice([&](const torch::Device& device) {
9692 torch::Tensor lazy_input = CopyToDevice(input, device);
9693 torch::Tensor lazy_output = torch::reflection_pad2d(lazy_input, pad);
9694 AllClose(output, lazy_output);
9695 });
9696}
9697
9698TEST_F(LazyOpsTest, TestReflectionPad2dBackward) {
9699 std::vector<int64_t> pad{2, 3, 1, 2};
9700 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
9701 return torch::reflection_pad2d(inputs[0], pad);
9702 };
9703 ForEachDevice([&](const torch::Device& device) {
9704 TestBackward(
9705 {torch::rand(
9706 {1, 2, 4, 4},
9707 torch::TensorOptions(torch::kFloat)
9708 .device(DefaultDevice())
9709 .requires_grad(true))},
9710 device,
9711 testfn);
9712 });
9713}
9714
9715TEST_F(LazyOpsTest, TestReplicationPad1d) {
9716 torch::Tensor input = torch::rand(
9717 {1, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9718 std::vector<int64_t> pad{1, 2};
9719 torch::Tensor output = torch::replication_pad1d(input, pad);
9720 ForEachDevice([&](const torch::Device& device) {
9721 torch::Tensor lazy_input = CopyToDevice(input, device);
9722 torch::Tensor lazy_output = torch::replication_pad1d(lazy_input, pad);
9723 AllClose(output, lazy_output);
9724 });
9725}
9726
9727TEST_F(LazyOpsTest, TestReplicationPad1dZeroPad) {
9728 torch::Tensor input = torch::rand(
9729 {1, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9730 std::vector<int64_t> pad{1, 0};
9731 torch::Tensor output = torch::replication_pad1d(input, pad);
9732 ForEachDevice([&](const torch::Device& device) {
9733 torch::Tensor lazy_input = CopyToDevice(input, device);
9734 torch::Tensor lazy_output = torch::replication_pad1d(lazy_input, pad);
9735 AllClose(output, lazy_output);
9736 });
9737}
9738
9739TEST_F(LazyOpsTest, TestReplicationPad1dBackward) {
9740 std::vector<int64_t> pad{2, 3};
9741 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
9742 return torch::replication_pad1d(inputs[0], pad);
9743 };
9744 ForEachDevice([&](const torch::Device& device) {
9745 TestBackward(
9746 {torch::rand(
9747 {2, 4},
9748 torch::TensorOptions(torch::kFloat)
9749 .device(DefaultDevice())
9750 .requires_grad(true))},
9751 device,
9752 testfn);
9753 });
9754}
9755
9756TEST_F(LazyOpsTest, TestReplicationPad2d) {
9757 torch::Tensor input = torch::rand(
9758 {1, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9759 std::vector<int64_t> pad{1, 2, 2, 1};
9760 torch::Tensor output = torch::replication_pad2d(input, pad);
9761 ForEachDevice([&](const torch::Device& device) {
9762 torch::Tensor lazy_input = CopyToDevice(input, device);
9763 torch::Tensor lazy_output = torch::replication_pad2d(lazy_input, pad);
9764 AllClose(output, lazy_output);
9765 });
9766}
9767
9768TEST_F(LazyOpsTest, TestReplicationPad2dZeroPad) {
9769 torch::Tensor input = torch::rand(
9770 {1, 3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9771 std::vector<int64_t> pad{1, 0, 0, 1};
9772 torch::Tensor output = torch::replication_pad2d(input, pad);
9773 ForEachDevice([&](const torch::Device& device) {
9774 torch::Tensor lazy_input = CopyToDevice(input, device);
9775 torch::Tensor lazy_output = torch::replication_pad2d(lazy_input, pad);
9776 AllClose(output, lazy_output);
9777 });
9778}
9779
9780TEST_F(LazyOpsTest, TestReplicationPad2dBackward) {
9781 std::vector<int64_t> pad{2, 3, 1, 1};
9782 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
9783 return torch::replication_pad2d(inputs[0], pad);
9784 };
9785 ForEachDevice([&](const torch::Device& device) {
9786 TestBackward(
9787 {torch::rand(
9788 {2, 3, 4},
9789 torch::TensorOptions(torch::kFloat)
9790 .device(DefaultDevice())
9791 .requires_grad(true))},
9792 device,
9793 testfn);
9794 });
9795}
9796
9797TEST_F(LazyOpsTest, TestAsStrided) {
9798 torch::Tensor input = torch::rand(
9799 {128, 320}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9800 std::vector<int64_t> size = {128, 20, 4, 4};
9801 std::vector<int64_t> stride = {320, 16, 4, 1};
9802 torch::Tensor output =
9803 torch::as_strided(input, /*size=*/size, /*stride=*/stride);
9804 ForEachDevice([&](const torch::Device& device) {
9805 torch::Tensor lazy_input = CopyToDevice(input, device);
9806 torch::Tensor lazy_output =
9807 torch::as_strided(lazy_input, /*size=*/size, /*stride=*/stride);
9808 AllClose(output, lazy_output);
9809 });
9810}
9811
9812TEST_F(LazyOpsTest, TestAsStridedInPlace) {
9813 torch::Tensor input = torch::rand(
9814 {128, 320}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9815 std::vector<int64_t> size = {128, 20, 4, 4};
9816 std::vector<int64_t> stride = {320, 16, 4, 1};
9817 ForEachDevice([&](const torch::Device& device) {
9818 torch::Tensor lazy_input = CopyToDevice(input, device);
9819 torch::Tensor output =
9820 torch::as_strided_(input, /*size=*/size, /*stride=*/stride);
9821 torch::Tensor lazy_output =
9822 torch::as_strided_(lazy_input, /*size=*/size, /*stride=*/stride);
9823 AllClose(output, lazy_output);
9824 AllClose(input, lazy_input);
9825 });
9826}
9827
9828TEST_F(LazyOpsTest, TestAsStridedWithOffset) {
9829 torch::Tensor input = torch::rand(
9830 {4, 8, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9831 std::vector<int64_t> size = {4, 4, 2};
9832 std::vector<int64_t> stride = {8, 2, 1};
9833 int64_t storage_offset = 4;
9834 torch::Tensor output = torch::as_strided(
9835 input,
9836 /*size=*/size,
9837 /*stride=*/stride,
9838 /*storage_offset=*/storage_offset);
9839 ForEachDevice([&](const torch::Device& device) {
9840 torch::Tensor lazy_input = CopyToDevice(input, device);
9841 torch::Tensor lazy_output = torch::as_strided(
9842 lazy_input,
9843 /*size=*/size,
9844 /*stride=*/stride,
9845 /*storage_offset=*/storage_offset);
9846 AllClose(output, lazy_output);
9847 });
9848}
9849
9850TEST_F(LazyOpsTest, TestAsStridedWithInplaceCopy) {
9851 torch::Tensor grad = torch::ones(
9852 {4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
9853 std::vector<int64_t> size = {4};
9854 std::vector<int64_t> stride = {1};
9855 torch::Tensor output = torch::zeros({4}, grad.options());
9856 output.as_strided(size, stride).copy_(grad);
9857 ForEachDevice([&](const torch::Device& device) {
9858 torch::Tensor lazy_grad = CopyToDevice(grad, device);
9859 torch::Tensor lazy_output = torch::zeros({4}, lazy_grad.options());
9860 lazy_output.as_strided(size, stride).copy_(lazy_grad);
9861 AllClose(output, lazy_output);
9862 });
9863}
9864
9865TEST_F(LazyOpsTest, TestEmptyStrided) {
9866 std::vector<int64_t> size = {4, 4, 2};
9867 std::vector<int64_t> stride = {8, 2, 1};
9868 torch::Tensor output = torch::empty_strided(/*size=*/size, /*stride=*/stride);
9869 ForEachDevice([&](const torch::Device& device) {
9870 torch::Tensor lazy_output =
9871 torch::empty_strided(/*size=*/size, /*stride=*/stride);
9872 EXPECT_EQ(output.sizes(), lazy_output.sizes());
9873 EXPECT_EQ(output.strides(), lazy_output.strides());
9874 });
9875}
9876
9877TEST_F(LazyOpsTest, TestAvgPool2DBackward) {
9878 int kernel_size = 2;
9879 for (int stride = 1; stride <= 2; ++stride) {
9880 for (int padding = 0; padding <= 1; ++padding) {
9881 for (bool count_include_pad : {true, false}) {
9882 // Test ceil_mode=true through the CPU interop.
9883 for (bool ceil_mode : {false, true}) {
9884 auto testfn =
9885 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
9886 return torch::avg_pool2d(
9887 inputs[0],
9888 /*kernel_size=*/{kernel_size, kernel_size},
9889 /*stride=*/{stride, stride},
9890 /*padding=*/{padding, padding},
9891 /*ceil_mode=*/ceil_mode,
9892 /*count_include_pad=*/count_include_pad);
9893 };
9894
9895 ForEachDevice([&](const torch::Device& device) {
9896 TestBackward(
9897 {torch::rand(
9898 {1, 1, 7, 7},
9899 torch::TensorOptions(torch::kFloat)
9900 .device(DefaultDevice())
9901 .requires_grad(true))},
9902 device,
9903 testfn);
9904 });
9905 }
9906 }
9907 }
9908 }
9909}
9910
9911TEST_F(LazyOpsTest, TestAvgPool3DBackward) {
9912 int kernel_size = 2;
9913 for (int stride = 1; stride <= 2; ++stride) {
9914 for (int padding = 0; padding <= 1; ++padding) {
9915 for (bool count_include_pad : {true, false}) {
9916 // Test ceil_mode=true through the CPU interop.
9917 for (bool ceil_mode : {false, true}) {
9918 auto testfn =
9919 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
9920 return torch::avg_pool3d(
9921 inputs[0],
9922 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
9923 /*stride=*/{stride, stride, stride},
9924 /*padding=*/{padding, padding, padding},
9925 /*ceil_mode=*/ceil_mode,
9926 /*count_include_pad=*/count_include_pad);
9927 };
9928
9929 ForEachDevice([&](const torch::Device& device) {
9930 TestBackward(
9931 {torch::rand(
9932 {1, 1, 7, 7, 7},
9933 torch::TensorOptions(torch::kFloat)
9934 .device(DefaultDevice())
9935 .requires_grad(true))},
9936 device,
9937 testfn);
9938 });
9939 }
9940 }
9941 }
9942 }
9943}
9944
9945TEST_F(LazyOpsTest, TestAvgPool2DNoBatchBackward) {
9946 int kernel_size = 2;
9947 for (int stride = 1; stride <= 2; ++stride) {
9948 for (int padding = 0; padding <= 1; ++padding) {
9949 for (bool count_include_pad : {true, false}) {
9950 // Test ceil_mode=true through the CPU interop.
9951 for (bool ceil_mode : {false, true}) {
9952 auto testfn =
9953 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
9954 return torch::avg_pool2d(
9955 inputs[0],
9956 /*kernel_size=*/{kernel_size, kernel_size},
9957 /*stride=*/{stride, stride},
9958 /*padding=*/{padding, padding},
9959 /*ceil_mode=*/ceil_mode,
9960 /*count_include_pad=*/count_include_pad);
9961 };
9962
9963 ForEachDevice([&](const torch::Device& device) {
9964 TestBackward(
9965 {torch::rand(
9966 {1, 7, 7},
9967 torch::TensorOptions(torch::kFloat)
9968 .device(DefaultDevice())
9969 .requires_grad(true))},
9970 device,
9971 testfn);
9972 });
9973 }
9974 }
9975 }
9976 }
9977}
9978
9979TEST_F(LazyOpsTest, TestAvgPool3DNoBatchBackward) {
9980 int kernel_size = 2;
9981 for (int stride = 1; stride <= 2; ++stride) {
9982 for (int padding = 0; padding <= 1; ++padding) {
9983 for (bool count_include_pad : {true, false}) {
9984 // Test ceil_mode=true through the CPU interop.
9985 for (bool ceil_mode : {false, true}) {
9986 auto testfn =
9987 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
9988 return torch::avg_pool3d(
9989 inputs[0],
9990 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
9991 /*stride=*/{stride, stride, stride},
9992 /*padding=*/{padding, padding, padding},
9993 /*ceil_mode=*/ceil_mode,
9994 /*count_include_pad=*/count_include_pad);
9995 };
9996
9997 ForEachDevice([&](const torch::Device& device) {
9998 TestBackward(
9999 {torch::rand(
10000 {1, 7, 7, 7},
10001 torch::TensorOptions(torch::kFloat)
10002 .device(DefaultDevice())
10003 .requires_grad(true))},
10004 device,
10005 testfn);
10006 });
10007 }
10008 }
10009 }
10010 }
10011}
10012
10013TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DNoBatchBackward) {
10014 if (IsCuda()) {
10015 GTEST_SKIP();
10016 }
10017 for (int64_t output_size : {7, 4}) {
10018 auto testfn =
10019 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10020 return torch::adaptive_avg_pool3d(
10021 inputs[0], {output_size, output_size, output_size});
10022 };
10023 ForEachDevice([&](const torch::Device& device) {
10024 TestBackward(
10025 {torch::rand(
10026 {1, 56, 28, 28},
10027 torch::TensorOptions(torch::kFloat)
10028 .device(DefaultDevice())
10029 .requires_grad(true))},
10030 device,
10031 testfn);
10032 });
10033 }
10034}
10035
10036TEST_F(LazyOpsTest, TestAdaptiveAvgPool3DBackward) {
10037 if (IsCuda()) {
10038 GTEST_SKIP();
10039 }
10040 for (int64_t output_size : {7, 4}) {
10041 auto testfn =
10042 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10043 return torch::adaptive_avg_pool3d(
10044 inputs[0], {output_size, output_size, output_size});
10045 };
10046 ForEachDevice([&](const torch::Device& device) {
10047 TestBackward(
10048 {torch::rand(
10049 {4, 1, 56, 28, 28},
10050 torch::TensorOptions(torch::kFloat)
10051 .device(DefaultDevice())
10052 .requires_grad(true))},
10053 device,
10054 testfn);
10055 });
10056 }
10057}
10058
10059TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DBackward) {
10060 for (int64_t output_size : {7, 8}) {
10061 auto testfn =
10062 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10063 return torch::adaptive_avg_pool2d(inputs[0], {output_size, output_size});
10064 };
10065 ForEachDevice([&](const torch::Device& device) {
10066 TestBackward(
10067 {torch::rand(
10068 {4, 1, 56, 56},
10069 torch::TensorOptions(torch::kFloat)
10070 .device(DefaultDevice())
10071 .requires_grad(true))},
10072 device,
10073 testfn);
10074 });
10075 }
10076}
10077
10078TEST_F(LazyOpsTest, TestAdaptiveAvgPool2DNoBatchBackward) {
10079 for (int64_t output_size : {7, 8}) {
10080 auto testfn =
10081 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10082 return torch::adaptive_avg_pool2d(inputs[0], {output_size, output_size});
10083 };
10084 ForEachDevice([&](const torch::Device& device) {
10085 TestBackward(
10086 {torch::rand(
10087 {1, 56, 56},
10088 torch::TensorOptions(torch::kFloat).requires_grad(true))},
10089 device,
10090 testfn);
10091 });
10092 }
10093}
10094
10095TEST_F(LazyOpsTest, TestConv2D) {
10096 int in_channels = 4;
10097 int out_channels = 4;
10098 int kernel_size = 3;
10099 for (int stride = 1; stride <= 3; ++stride) {
10100 for (int padding = 0; padding <= 2; ++padding) {
10101 for (bool with_bias : {true, false}) {
10102 for (int dilation = 1; dilation <= 3; ++dilation) {
10103 for (int groups :
10104 {1, 2, 4}) { // covers normal, grouped, depthwise conv.
10105 ForEachDevice([&](const torch::Device& device) {
10106 torch::Tensor input = torch::rand(
10107 {1, in_channels, 7, 7},
10108 torch::TensorOptions(torch::kDouble).device(DefaultDevice()));
10109 torch::Tensor weight = torch::rand(
10110 {out_channels,
10111 in_channels / groups,
10112 kernel_size,
10113 kernel_size},
10114 torch::TensorOptions(torch::kDouble).device(DefaultDevice()));
10115 torch::Tensor bias = with_bias
10116 ? torch::rand(
10117 {out_channels},
10118 torch::TensorOptions(torch::kDouble)
10119 .device(DefaultDevice()))
10120 : torch::Tensor();
10121
10122 torch::Tensor lazy_input = CopyToDevice(input, device);
10123 torch::Tensor lazy_weight = CopyToDevice(weight, device);
10124 torch::Tensor lazy_bias =
10125 with_bias ? CopyToDevice(bias, device) : torch::Tensor();
10126
10127 torch::Tensor output = torch::conv2d(
10128 input,
10129 weight,
10130 bias,
10131 /*stride=*/{stride, stride},
10132 /*padding=*/{padding, padding},
10133 /*dilation=*/{dilation, dilation},
10134 groups);
10135 torch::Tensor lazy_output = torch::conv2d(
10136 lazy_input,
10137 lazy_weight,
10138 lazy_bias,
10139 /*stride=*/{stride, stride},
10140 /*padding=*/{padding, padding},
10141 /*dilation=*/{dilation, dilation},
10142 groups);
10143 AllClose(output, lazy_output);
10144 });
10145 }
10146 }
10147 }
10148 }
10149 }
10150}
10151
10152TEST_F(LazyOpsTest, TestConv2DBackward) {
10153 int in_channels = 4;
10154 int out_channels = 4;
10155 int kernel_size = 3;
10156 for (int stride = 1; stride <= 3; ++stride) {
10157 for (int padding = 0; padding <= 2; ++padding) {
10158 for (bool with_bias : {true, false}) {
10159 for (int dilation = 1; dilation <= 3; ++dilation) {
10160 for (int groups :
10161 {1, 2, 4}) { // covers normal, grouped, depthwise conv.
10162 auto testfn =
10163 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10164 return torch::conv2d(
10165 inputs[0],
10166 inputs[1],
10167 inputs[2],
10168 /*stride=*/{stride, stride},
10169 /*padding=*/{padding, padding},
10170 /*dilation=*/{dilation, dilation},
10171 groups);
10172 };
10173
10174 ForEachDevice([&](const torch::Device& device) {
10175 torch::Tensor bias = with_bias
10176 ? torch::rand(
10177 {out_channels},
10178 torch::TensorOptions(torch::kDouble)
10179 .device(DefaultDevice()))
10180 : torch::Tensor();
10181 TestBackward(
10182 {torch::rand(
10183 {1, in_channels, 7, 7},
10184 torch::TensorOptions(torch::kDouble)
10185 .device(DefaultDevice())
10186 .requires_grad(true)),
10187 torch::rand(
10188 {out_channels,
10189 in_channels / groups,
10190 kernel_size,
10191 kernel_size},
10192 torch::TensorOptions(torch::kDouble)
10193 .device(DefaultDevice())
10194 .requires_grad(true)),
10195 bias},
10196 device,
10197 testfn);
10198 });
10199 }
10200 };
10201 }
10202 }
10203 }
10204}
10205
10206TEST_F(LazyOpsTest, TestTransposedConv2DBackward) {
10207 int in_channels = 4;
10208 int out_channels = 4;
10209 int kernel_size = 3;
10210 for (int stride = 1; stride <= 2; ++stride) {
10211 for (int padding = 0; padding <= 1; ++padding) {
10212 for (int dilation = 1; dilation <= 2; ++dilation) {
10213 for (int output_padding = 0;
10214 output_padding < std::max(stride, dilation);
10215 ++output_padding) {
10216 for (bool with_bias : {true, false}) {
10217 for (int groups :
10218 {1, 2, 4}) { // covers normal, grouped, depthwise conv.
10219 auto testfn = [&](const std::vector<torch::Tensor>& inputs)
10220 -> torch::Tensor {
10221 return torch::conv_transpose2d(
10222 inputs[0],
10223 inputs[1],
10224 inputs[2],
10225 /*stride=*/{stride, stride + 1},
10226 /*padding=*/{padding, padding + 1},
10227 /*output_padding=*/output_padding,
10228 /*groups=*/groups,
10229 /*dilation=*/{dilation, dilation + 1});
10230 };
10231 ForEachDevice([&](const torch::Device& device) {
10232 torch::Tensor input = torch::rand(
10233 {4, out_channels, 7, 7},
10234 torch::TensorOptions(torch::kFloat)
10235 .device(DefaultDevice())
10236 .requires_grad(true));
10237 torch::Tensor weight = torch::rand(
10238 {out_channels,
10239 in_channels / groups,
10240 kernel_size,
10241 kernel_size},
10242 torch::TensorOptions(torch::kFloat)
10243 .device(DefaultDevice())
10244 .requires_grad(true));
10245 torch::Tensor bias = with_bias
10246 ? torch::rand(
10247 {in_channels},
10248 torch::TensorOptions(torch::kFloat)
10249 .device(DefaultDevice())
10250 .requires_grad(true))
10251 : torch::Tensor();
10252 TestBackward(
10253 {input, weight, bias},
10254 device,
10255 testfn,
10256 /*rtol=*/1e-5,
10257 /*atol=*/1e-5);
10258 });
10259 }
10260 };
10261 }
10262 }
10263 }
10264 }
10265}
10266
10267TEST_F(LazyOpsTest, TestConv3DBackward) {
10268 int in_channels = 4;
10269 int out_channels = 4;
10270 int kernel_size = 3;
10271 for (int stride = 1; stride <= 3; ++stride) {
10272 for (int padding = 1; padding <= 2; ++padding) {
10273 for (bool with_bias : {true, false}) {
10274 for (int dilation = 1; dilation <= 2; ++dilation) {
10275 for (int groups :
10276 {1, 2, 4}) { // covers normal, grouped, depthwise conv.
10277 auto testfn =
10278 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10279 return torch::conv3d(
10280 inputs[0],
10281 inputs[1],
10282 inputs[2],
10283 /*stride=*/{stride, stride, stride},
10284 /*padding=*/{padding, padding, padding},
10285 /*dilation=*/{dilation, dilation, dilation},
10286 groups);
10287 };
10288
10289 ForEachDevice([&](const torch::Device& device) {
10290 torch::Tensor bias = with_bias
10291 ? torch::rand(
10292 {out_channels},
10293 torch::TensorOptions(torch::kDouble)
10294 .device(DefaultDevice()))
10295 : torch::Tensor();
10296 TestBackward(
10297 {torch::rand(
10298 {4, in_channels, 7, 7, 7},
10299 torch::TensorOptions(torch::kDouble)
10300 .device(DefaultDevice())
10301 .requires_grad(true)),
10302 torch::rand(
10303 {out_channels,
10304 in_channels / groups,
10305 kernel_size,
10306 kernel_size,
10307 kernel_size},
10308 torch::TensorOptions(torch::kDouble)
10309 .device(DefaultDevice())
10310 .requires_grad(true)),
10311 bias},
10312 device,
10313 testfn);
10314 });
10315 }
10316 };
10317 }
10318 }
10319 }
10320}
10321
10322TEST_F(LazyOpsTest, TestTransposedConv3DBackward) {
10323 int in_channels = 4;
10324 int out_channels = 4;
10325 int kernel_size = 3;
10326 for (int stride = 1; stride <= 2; ++stride) {
10327 for (int padding = 0; padding <= 1; ++padding) {
10328 for (int dilation = 1; dilation <= 2; ++dilation) {
10329 for (int output_padding = 0;
10330 output_padding < std::max(stride, dilation);
10331 ++output_padding) {
10332 for (bool with_bias : {true, false}) {
10333 for (int groups :
10334 {1, 2, 4}) { // covers normal, grouped, depthwise conv.
10335 auto testfn = [&](const std::vector<torch::Tensor>& inputs)
10336 -> torch::Tensor {
10337 return torch::conv_transpose3d(
10338 inputs[0],
10339 inputs[1],
10340 inputs[2],
10341 /*stride=*/{stride, stride + 1, stride},
10342 /*padding=*/{padding, padding + 1, stride},
10343 /*output_padding=*/output_padding,
10344 /*groups=*/groups,
10345 /*dilation=*/{dilation, dilation + 1, dilation});
10346 };
10347 ForEachDevice([&](const torch::Device& device) {
10348 torch::Tensor input = torch::rand(
10349 {4, out_channels, 7, 7, 7},
10350 torch::TensorOptions(torch::kDouble)
10351 .device(DefaultDevice())
10352 .requires_grad(true));
10353 torch::Tensor weight = torch::rand(
10354 {out_channels,
10355 in_channels / groups,
10356 kernel_size,
10357 kernel_size,
10358 kernel_size},
10359 torch::TensorOptions(torch::kDouble)
10360 .device(DefaultDevice())
10361 .requires_grad(true));
10362 torch::Tensor bias = with_bias
10363 ? torch::rand(
10364 {in_channels},
10365 torch::TensorOptions(torch::kDouble)
10366 .device(DefaultDevice())
10367 .requires_grad(true))
10368 : torch::Tensor();
10369 TestBackward({input, weight, bias}, device, testfn);
10370 });
10371 }
10372 };
10373 }
10374 }
10375 }
10376 }
10377}
10378
10379TEST_F(LazyOpsTest, TestMaxPool2DBackward) {
10380 int kernel_size = 3;
10381 for (int stride = 1; stride <= 2; ++stride) {
10382 for (int padding = 0; padding <= 1; ++padding) {
10383 // Test ceil_mode=true through the CPU interop.
10384 for (bool ceil_mode : {false, true}) {
10385 auto testfn =
10386 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10387 return torch::max_pool2d(
10388 inputs[0],
10389 /*kernel_size=*/{kernel_size, kernel_size},
10390 /*stride=*/{stride, stride},
10391 /*padding=*/{padding, padding},
10392 /*dilation=*/{1, 1},
10393 /*ceil_mode=*/ceil_mode);
10394 };
10395
10396 ForEachDevice([&](const torch::Device& device) {
10397 TestBackward(
10398 {torch::rand(
10399 {1, 2, 8, 8},
10400 torch::TensorOptions(torch::kFloat)
10401 .device(DefaultDevice())
10402 .requires_grad(true))},
10403 device,
10404 testfn);
10405 });
10406 }
10407 }
10408 }
10409}
10410
10411TEST_F(LazyOpsTest, TestMaxPool3DBackward) {
10412 int kernel_size = 3;
10413 for (int stride = 1; stride <= 2; ++stride) {
10414 for (int padding = 0; padding <= 1; ++padding) {
10415 // Test ceil_mode=true through the CPU interop.
10416 for (bool ceil_mode : {false, true}) {
10417 auto testfn =
10418 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10419 return torch::max_pool3d(
10420 inputs[0],
10421 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
10422 /*stride=*/{stride, stride, stride},
10423 /*padding=*/{padding, padding, padding},
10424 /*dilation=*/{1, 1, 1},
10425 /*ceil_mode=*/ceil_mode);
10426 };
10427
10428 ForEachDevice([&](const torch::Device& device) {
10429 TestBackward(
10430 {torch::rand(
10431 {1, 2, 4, 4, 4},
10432 torch::TensorOptions(torch::kFloat)
10433 .device(DefaultDevice())
10434 .requires_grad(true))},
10435 device,
10436 testfn);
10437 });
10438 }
10439 }
10440 }
10441}
10442
10443TEST_F(LazyOpsTest, TestMaxPool2DNoBatchBackward) {
10444 int kernel_size = 3;
10445 for (int stride = 1; stride <= 2; ++stride) {
10446 for (int padding = 0; padding <= 1; ++padding) {
10447 // Test ceil_mode=true through the CPU interop.
10448 for (bool ceil_mode : {false, true}) {
10449 auto testfn =
10450 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10451 return torch::max_pool2d(
10452 inputs[0],
10453 /*kernel_size=*/{kernel_size, kernel_size},
10454 /*stride=*/{stride, stride},
10455 /*padding=*/{padding, padding},
10456 /*dilation=*/{1, 1},
10457 /*ceil_mode=*/ceil_mode);
10458 };
10459
10460 ForEachDevice([&](const torch::Device& device) {
10461 TestBackward(
10462 {torch::rand(
10463 {2, 8, 8},
10464 torch::TensorOptions(torch::kFloat)
10465 .device(DefaultDevice())
10466 .requires_grad(true))},
10467 device,
10468 testfn);
10469 });
10470 }
10471 }
10472 }
10473}
10474
10475TEST_F(LazyOpsTest, TestMaxPool3DNoBatchBackward) {
10476 int kernel_size = 3;
10477 for (int stride = 1; stride <= 2; ++stride) {
10478 for (int padding = 0; padding <= 1; ++padding) {
10479 // Test ceil_mode=true through the CPU interop.
10480 for (bool ceil_mode : {false, true}) {
10481 auto testfn =
10482 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10483 return torch::max_pool3d(
10484 inputs[0],
10485 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
10486 /*stride=*/{stride, stride, stride},
10487 /*padding=*/{padding, padding, padding},
10488 /*dilation=*/{1, 1, 1},
10489 /*ceil_mode=*/ceil_mode);
10490 };
10491
10492 ForEachDevice([&](const torch::Device& device) {
10493 TestBackward(
10494 {torch::rand(
10495 {2, 4, 4, 4},
10496 torch::TensorOptions(torch::kFloat)
10497 .device(DefaultDevice())
10498 .requires_grad(true))},
10499 device,
10500 testfn);
10501 });
10502 }
10503 }
10504 }
10505}
10506
10507TEST_F(LazyOpsTest, TestMaxUnpool2DBackward) {
10508 int kernel_size = 2;
10509 torch::Tensor input = torch::rand(
10510 {2, 2, 8, 8},
10511 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
10512 for (int stride = 1; stride <= 2; ++stride) {
10513 for (int padding = 0; padding <= 1; ++padding) {
10514 // Test ceil_mode=true through the CPU interop.
10515 for (bool ceil_mode : {false, true}) {
10516 for (int dilation = 1; dilation <= 2; ++dilation) {
10517 torch::Tensor output;
10518 torch::Tensor indices;
10519 std::tie(output, indices) = torch::max_pool2d_with_indices(
10520 input,
10521 /*kernel_size=*/{kernel_size, kernel_size},
10522 /*stride=*/{stride, stride},
10523 /*padding=*/{padding, padding},
10524 /*dilation=*/{dilation, dilation},
10525 /*ceil_mode=*/ceil_mode);
10526
10527 std::vector<int64_t> output_size({input.size(2), input.size(3)});
10528 auto testfn =
10529 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10530 return torch::max_unpool2d(inputs[0], inputs[1], output_size);
10531 };
10532
10533 ForEachDevice([&](const torch::Device& device) {
10534 TestBackward(
10535 {output.requires_grad_(true), indices}, device, testfn);
10536 });
10537 }
10538 }
10539 }
10540 }
10541}
10542
10543TEST_F(LazyOpsTest, TestMaxUnpool3DBackward) {
10544 int kernel_size = 2;
10545 torch::Tensor input = torch::rand(
10546 {1, 1, 4, 4, 4},
10547 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
10548 for (int stride = 1; stride <= 2; ++stride) {
10549 for (int padding = 0; padding <= 1; ++padding) {
10550 // Test ceil_mode=true through the CPU interop.
10551 for (bool ceil_mode : {false, true}) {
10552 for (int dilation = 1; dilation <= 2; ++dilation) {
10553 torch::Tensor output;
10554 torch::Tensor indices;
10555 std::tie(output, indices) = torch::max_pool3d_with_indices(
10556 input,
10557 /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
10558 /*stride=*/{stride, stride, stride},
10559 /*padding=*/{padding, padding, padding},
10560 /*dilation=*/{dilation, dilation, dilation},
10561 /*ceil_mode=*/ceil_mode);
10562
10563 std::vector<int64_t> output_size(
10564 {input.size(2), input.size(3), input.size(4)});
10565 auto testfn =
10566 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10567 return torch::max_unpool3d(
10568 inputs[0],
10569 inputs[1],
10570 output_size,
10571 /*stride=*/{stride, stride, stride},
10572 /*padding=*/{padding, padding, padding});
10573 };
10574
10575 ForEachDevice([&](const torch::Device& device) {
10576 TestBackward(
10577 {output.requires_grad_(true), indices}, device, testfn);
10578 });
10579 }
10580 }
10581 }
10582 }
10583}
10584
10585TEST_F(LazyOpsTest, TestTanhBackward) {
10586 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10587 return torch::tanh(inputs[0]);
10588 };
10589 ForEachDevice([&](const torch::Device& device) {
10590 TestBackward(
10591 {torch::rand(
10592 {2, 2},
10593 torch::TensorOptions(torch::kFloat)
10594 .device(DefaultDevice())
10595 .requires_grad(true))},
10596 device,
10597 testfn,
10598 /*rtol=*/1e-3,
10599 /*atol=*/1e-5);
10600 });
10601}
10602
10603TEST_F(LazyOpsTest, TestSigmoidBackward) {
10604 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10605 return torch::sigmoid(inputs[0]);
10606 };
10607 ForEachDevice([&](const torch::Device& device) {
10608 TestBackward(
10609 {torch::rand(
10610 {2, 2},
10611 torch::TensorOptions(torch::kFloat)
10612 .device(DefaultDevice())
10613 .requires_grad(true))},
10614 device,
10615 testfn);
10616 });
10617}
10618
10619TEST_F(LazyOpsTest, TestLogSigmoidBackward) {
10620 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10621 return torch::log_sigmoid(inputs[0]);
10622 };
10623 ForEachDevice([&](const torch::Device& device) {
10624 TestBackward(
10625 {torch::rand(
10626 {2, 2},
10627 torch::TensorOptions(torch::kFloat)
10628 .device(DefaultDevice())
10629 .requires_grad(true))},
10630 device,
10631 testfn,
10632 /*rtol=*/1e-3,
10633 /*atol=*/1e-5);
10634 });
10635}
10636
10637TEST_F(LazyOpsTest, TestLogSoftmaxBackward) {
10638 for (int dim = -4; dim < 4; ++dim) {
10639 auto testfn =
10640 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10641 return torch::log_softmax(inputs[0], dim);
10642 };
10643
10644 ForEachDevice([&](const torch::Device& device) {
10645 TestBackward(
10646 {torch::rand(
10647 {5, 3, 4, 2},
10648 torch::TensorOptions(torch::kFloat)
10649 .device(DefaultDevice())
10650 .requires_grad(true))},
10651 device,
10652 testfn,
10653 /*rtol=*/1e-3,
10654 /*atol=*/1e-4);
10655 });
10656 }
10657}
10658
10659TEST_F(LazyOpsTest, TestSoftmaxBackward) {
10660 for (int dim = -4; dim < 4; ++dim) {
10661 auto testfn =
10662 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10663 return torch::softmax(inputs[0], dim);
10664 };
10665
10666 ForEachDevice([&](const torch::Device& device) {
10667 TestBackward(
10668 {torch::rand(
10669 {5, 3, 4, 2},
10670 torch::TensorOptions(torch::kFloat)
10671 .device(DefaultDevice())
10672 .requires_grad(true))},
10673 device,
10674 testfn,
10675 /*rtol=*/1e-3,
10676 /*atol=*/1e-4);
10677 });
10678 }
10679}
10680
10681TEST_F(LazyOpsTest, TestSoftplusBackward) {
10682 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10683 return torch::softplus(inputs[0]);
10684 };
10685 ForEachDevice([&](const torch::Device& device) {
10686 TestBackward(
10687 {torch::rand(
10688 {2, 1, 4, 6},
10689 torch::TensorOptions(torch::kFloat)
10690 .device(DefaultDevice())
10691 .requires_grad(true))},
10692 device,
10693 testfn,
10694 /*rtol=*/1e-4);
10695 });
10696}
10697
10698TEST_F(LazyOpsTest, TestReluBackward) {
10699 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10700 return torch::relu(inputs[0]);
10701 };
10702 ForEachDevice([&](const torch::Device& device) {
10703 TestBackward(
10704 {torch::rand(
10705 {2, 1, 4, 6},
10706 torch::TensorOptions(torch::kFloat)
10707 .device(DefaultDevice())
10708 .requires_grad(true))},
10709 device,
10710 testfn);
10711 });
10712}
10713
10714TEST_F(LazyOpsTest, TestRreluBackward) {
10715 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10716 return torch::rrelu(inputs[0]);
10717 };
10718 ForEachDevice([&](const torch::Device& device) {
10719 TestBackward(
10720 {torch::rand(
10721 {2, 1, 4, 6},
10722 torch::TensorOptions(torch::kFloat)
10723 .device(DefaultDevice())
10724 .requires_grad(true))},
10725 device,
10726 testfn);
10727 });
10728}
10729
10730TEST_F(LazyOpsTest, TestHardshrinkBackward) {
10731 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10732 return torch::hardshrink(inputs[0]);
10733 };
10734 ForEachDevice([&](const torch::Device& device) {
10735 TestBackward(
10736 {torch::randn(
10737 {100},
10738 torch::TensorOptions(torch::kFloat)
10739 .device(DefaultDevice())
10740 .requires_grad(true))},
10741 device,
10742 testfn);
10743 });
10744}
10745
10746TEST_F(LazyOpsTest, TestSoftshrinkBackward) {
10747 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10748 return torch::softshrink(inputs[0]);
10749 };
10750 ForEachDevice([&](const torch::Device& device) {
10751 TestBackward(
10752 {torch::randn(
10753 {100},
10754 torch::TensorOptions(torch::kFloat)
10755 .device(DefaultDevice())
10756 .requires_grad(true))},
10757 device,
10758 testfn);
10759 });
10760}
10761
10762TEST_F(LazyOpsTest, TestHardtanhBackward) {
10763 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10764 return torch::hardtanh(inputs[0]);
10765 };
10766 ForEachDevice([&](const torch::Device& device) {
10767 TestBackward(
10768 {torch::randn(
10769 {100},
10770 torch::TensorOptions(torch::kFloat)
10771 .device(DefaultDevice())
10772 .requires_grad(true))},
10773 device,
10774 testfn);
10775 });
10776}
10777
10778TEST_F(LazyOpsTest, TestEluBackward) {
10779 torch::Scalar alpha = 0.5;
10780 torch::Scalar scale = 2.5;
10781 torch::Scalar input_scale = 1.5;
10782 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10783 return torch::elu(inputs[0], alpha, scale, input_scale);
10784 };
10785 ForEachDevice([&](const torch::Device& device) {
10786 TestBackward(
10787 {torch::rand(
10788 {2, 1, 4, 6},
10789 torch::TensorOptions(torch::kFloat)
10790 .device(DefaultDevice())
10791 .requires_grad(true))},
10792 device,
10793 testfn);
10794 });
10795}
10796
10797TEST_F(LazyOpsTest, TestGeluBackward) {
10798 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10799 return torch::gelu(inputs[0]);
10800 };
10801 ForEachDevice([&](const torch::Device& device) {
10802 TestBackward(
10803 {torch::rand(
10804 {2, 3},
10805 torch::TensorOptions(torch::kFloat)
10806 .device(DefaultDevice())
10807 .requires_grad(true))},
10808 device,
10809 testfn);
10810 });
10811 ExpectCounterChanged("lazy::gelu_backward", GetIgnoredCounters());
10812}
10813
10814TEST_F(LazyOpsTest, TestLeakyReluBackward) {
10815 double negative_slope = 0.01;
10816 auto testfn = [=](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10817 return torch::leaky_relu(inputs[0], negative_slope);
10818 };
10819 ForEachDevice([&](const torch::Device& device) {
10820 TestBackward(
10821 {torch::rand(
10822 {2, 1, 4, 6},
10823 torch::TensorOptions(torch::kFloat)
10824 .device(DefaultDevice())
10825 .requires_grad(true))},
10826 device,
10827 testfn);
10828 });
10829}
10830
10831TEST_F(LazyOpsTest, TestTransposeBackward) {
10832 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10833 return torch::t(inputs[0]);
10834 };
10835 ForEachDevice([&](const torch::Device& device) {
10836 TestBackward(
10837 {torch::rand(
10838 {2, 3},
10839 torch::TensorOptions(torch::kFloat)
10840 .device(DefaultDevice())
10841 .requires_grad(true))},
10842 device,
10843 testfn);
10844 });
10845}
10846
10847TEST_F(LazyOpsTest, TestAddMatMulBackward) {
10848 int in_channels = 32;
10849 int out_channels = 320;
10850 int labels = 50;
10851 // Test beta != 1. through the CPU interop.
10852 for (double beta : {1., 2.}) {
10853 auto testfn =
10854 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10855 return torch::addmm(inputs[0], inputs[1], inputs[2], /*beta=*/beta);
10856 };
10857 ForEachDevice([&](const torch::Device& device) {
10858 TestBackward(
10859 {torch::rand(
10860 {labels},
10861 torch::TensorOptions(torch::kFloat)
10862 .device(DefaultDevice())
10863 .requires_grad(true)),
10864 torch::rand(
10865 {in_channels, out_channels},
10866 torch::TensorOptions(torch::kFloat)
10867 .device(DefaultDevice())
10868 .requires_grad(true)),
10869 torch::rand(
10870 {out_channels, labels},
10871 torch::TensorOptions(torch::kFloat)
10872 .device(DefaultDevice())
10873 .requires_grad(true))},
10874 device,
10875 testfn);
10876 });
10877 }
10878}
10879
10880TEST_F(LazyOpsTest, TestBinaryCrossEntropyBackward) {
10881 int batch = 6;
10882 int classes = 2;
10883 // TODO(asuhan): Fix the torch::kDouble case.
10884 for (auto dtype : {torch::kFloat}) {
10885 for (bool def_weight : {false, true}) {
10886 torch::Tensor input = torch::rand(
10887 {batch, classes}, torch::TensorOptions(dtype).requires_grad(true));
10888 torch::Tensor target =
10889 torch::rand({batch, classes}, torch::TensorOptions(dtype));
10890 torch::Tensor weight;
10891 if (def_weight) {
10892 weight = torch::rand({batch, classes}, torch::TensorOptions(dtype));
10893 }
10894 for (torch::Reduction::Reduction reduction :
10895 {torch::Reduction::Mean,
10896 torch::Reduction::Sum,
10897 torch::Reduction::None}) {
10898 auto testfn =
10899 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10900 return torch::binary_cross_entropy(
10901 /*self=*/inputs[0],
10902 /*target=*/inputs[1],
10903 /*weight=*/inputs[2],
10904 /*reduction=*/reduction);
10905 };
10906 ForEachDevice([&](const torch::Device& device) {
10907 TestBackward(
10908 {input, target, weight},
10909 device,
10910 testfn,
10911 /*rtol=*/1e-4,
10912 /*atol=*/1e-7);
10913 });
10914 }
10915 }
10916 }
10917}
10918
10919TEST_F(LazyOpsTest, TestNllLossBackward) {
10920 // TODO(whc) debug divide-by-zero failure under ASAN
10921 GTEST_SKIP();
10922
10923 int batch = 6;
10924 int classes = 2;
10925 // TODO(asuhan): Fix the torch::kDouble case.
10926 for (auto dtype : {torch::kFloat}) {
10927 for (int ignore_index : {-1, 0, 1, 5}) {
10928 for (bool def_weight : {false, true}) {
10929 torch::Tensor input = torch::rand(
10930 {batch, classes},
10931 torch::TensorOptions(dtype)
10932 .device(DefaultDevice())
10933 .requires_grad(true));
10934 torch::Tensor target = torch::randint(
10935 std::min(ignore_index, 0),
10936 classes,
10937 {batch},
10938 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
10939 torch::Tensor weight;
10940 if (def_weight) {
10941 weight = torch::rand(
10942 {classes}, torch::TensorOptions(dtype).device(DefaultDevice()));
10943 }
10944 for (torch::Reduction::Reduction reduction :
10945 {torch::Reduction::Mean,
10946 torch::Reduction::Sum,
10947 torch::Reduction::None}) {
10948 auto testfn =
10949 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
10950 return torch::nll_loss(
10951 /*self=*/inputs[0],
10952 /*target=*/inputs[1],
10953 /*weight=*/inputs[2],
10954 /*reduction=*/reduction,
10955 /*ignore_index=*/ignore_index);
10956 };
10957 ForEachDevice([&](const torch::Device& device) {
10958 TestBackward(
10959 {input, target, weight},
10960 device,
10961 testfn,
10962 /*rtol=*/1e-5,
10963 /*atol=*/1e-8);
10964 });
10965 }
10966 }
10967 }
10968 }
10969}
10970
10971TEST_F(LazyOpsTest, TestNllLoss2dBackward) {
10972 int batch = 6;
10973 int classes = 2;
10974 int height = 3;
10975 int width = 3;
10976 // TODO(asuhan): Fix the torch::kDouble case.
10977 for (auto dtype : {torch::kFloat}) {
10978 for (int ignore_index : {-1, 0, 1, 5}) {
10979 for (bool def_weight : {false, true}) {
10980 torch::Tensor input = torch::rand(
10981 {batch, classes, height, width},
10982 torch::TensorOptions(dtype)
10983 .device(DefaultDevice())
10984 .requires_grad(true));
10985 torch::Tensor target = torch::randint(
10986 std::min(ignore_index, 0),
10987 classes,
10988 {batch, height, width},
10989 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
10990 torch::Tensor weight;
10991 if (def_weight) {
10992 weight = torch::rand(
10993 {classes}, torch::TensorOptions(dtype).device(DefaultDevice()));
10994 }
10995 for (torch::Reduction::Reduction reduction :
10996 {torch::Reduction::Mean,
10997 torch::Reduction::Sum,
10998 torch::Reduction::None}) {
10999 auto testfn =
11000 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11001 return torch::nll_loss2d(
11002 /*self=*/inputs[0],
11003 /*target=*/inputs[1],
11004 /*weight=*/inputs[2],
11005 /*reduction=*/reduction,
11006 /*ignore_index=*/ignore_index);
11007 };
11008 ForEachDevice([&](const torch::Device& device) {
11009 TestBackward(
11010 {input, target, weight},
11011 device,
11012 testfn,
11013 /*rtol=*/1e-5,
11014 /*atol=*/1e-8);
11015 });
11016 }
11017 }
11018 }
11019 }
11020}
11021
11022TEST_F(LazyOpsTest, TestSmoothL1LossBackward) {
11023 torch::Tensor input = torch::randn(
11024 {2, 4},
11025 torch::TensorOptions(torch::kFloat)
11026 .device(DefaultDevice())
11027 .requires_grad(true));
11028 torch::Tensor target = torch::randn(
11029 {2, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11030 for (torch::Reduction::Reduction reduction :
11031 {torch::Reduction::None,
11032 torch::Reduction::Mean,
11033 torch::Reduction::Sum}) {
11034 for (double beta : {0.25, 1.}) {
11035 auto testfn =
11036 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11037 return torch::smooth_l1_loss(
11038 /*input=*/inputs[0],
11039 /*target=*/inputs[1],
11040 /*reduction=*/reduction,
11041 /*beta=*/beta);
11042 };
11043 ForEachDevice([&](const torch::Device& device) {
11044 TestBackward(
11045 {input, target},
11046 device,
11047 testfn,
11048 /*rtol=*/1e-5,
11049 /*atol=*/1e-8);
11050 });
11051 }
11052 }
11053}
11054
11055TEST_F(LazyOpsTest, TestViewBackward) {
11056 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11057 return inputs[0].view({-1, 320});
11058 };
11059 ForEachDevice([&](const torch::Device& device) {
11060 TestBackward(
11061 {torch::rand(
11062 {32, 20, 4, 4},
11063 torch::TensorOptions(torch::kFloat)
11064 .device(DefaultDevice())
11065 .requires_grad(true))},
11066 device,
11067 testfn);
11068 });
11069}
11070
11071TEST_F(LazyOpsTest, TestBatchNorm2DBackward) {
11072 double momentum = 0.1;
11073 double eps = 0.5;
11074 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11075 return torch::batch_norm(
11076 /*input=*/inputs[0],
11077 /*weight=*/inputs[1],
11078 /*bias=*/inputs[2],
11079 /*running_mean=*/inputs[3],
11080 /*running_var=*/inputs[4],
11081 /*training=*/true,
11082 /*momentum=*/momentum,
11083 /*eps=*/eps,
11084 /*cudnn_enabled=*/false);
11085 };
11086 int num_features = 3;
11087 torch::Tensor undef;
11088 for (bool undef_weight_bias : {false, true}) {
11089 ForEachDevice([&](const torch::Device& device) {
11090 torch::Tensor input = torch::rand(
11091 {2, num_features, 4, 4},
11092 torch::TensorOptions(torch::kFloat)
11093 .device(DefaultDevice())
11094 .requires_grad(true));
11095 torch::Tensor weight = undef_weight_bias
11096 ? undef
11097 : torch::rand(
11098 {num_features},
11099 torch::TensorOptions(torch::kFloat)
11100 .device(DefaultDevice())
11101 .requires_grad(true));
11102 torch::Tensor bias = undef_weight_bias
11103 ? undef
11104 : torch::rand(
11105 {num_features},
11106 torch::TensorOptions(torch::kFloat)
11107 .device(DefaultDevice())
11108 .requires_grad(true));
11109 torch::Tensor running_mean = torch::zeros(
11110 {num_features},
11111 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11112 torch::Tensor running_var = torch::ones(
11113 {num_features},
11114 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11115 TestBackward(
11116 {input, weight, bias, running_mean, running_var},
11117 device,
11118 testfn,
11119 /*rtol=*/1e-3,
11120 /*atol=*/1e-4);
11121 });
11122 }
11123}
11124
11125TEST_F(LazyOpsTest, TestBatchNorm3DBackward) {
11126 double momentum = 0.1;
11127 double eps = 0.5;
11128 auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11129 return torch::batch_norm(
11130 /*input=*/inputs[0],
11131 /*weight=*/inputs[1],
11132 /*bias=*/inputs[2],
11133 /*running_mean=*/inputs[3],
11134 /*running_var=*/inputs[4],
11135 /*training=*/true,
11136 /*momentum=*/momentum,
11137 /*eps=*/eps,
11138 /*cudnn_enabled=*/false);
11139 };
11140 int num_features = 3;
11141 torch::Tensor undef;
11142 for (bool undef_weight_bias : {false, true}) {
11143 ForEachDevice([&](const torch::Device& device) {
11144 torch::Tensor input = torch::rand(
11145 {2, num_features, 4, 4, 2},
11146 torch::TensorOptions(torch::kFloat)
11147 .device(DefaultDevice())
11148 .requires_grad(true));
11149 torch::Tensor weight = undef_weight_bias
11150 ? undef
11151 : torch::rand(
11152 {num_features},
11153 torch::TensorOptions(torch::kFloat)
11154 .device(DefaultDevice())
11155 .requires_grad(true));
11156 torch::Tensor bias = undef_weight_bias
11157 ? undef
11158 : torch::rand(
11159 {num_features},
11160 torch::TensorOptions(torch::kFloat)
11161 .device(DefaultDevice())
11162 .requires_grad(true));
11163 torch::Tensor running_mean = torch::zeros(
11164 {num_features},
11165 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11166 torch::Tensor running_var = torch::ones(
11167 {num_features},
11168 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11169 TestBackward(
11170 {input, weight, bias, running_mean, running_var},
11171 device,
11172 testfn,
11173 /*rtol=*/1e-3,
11174 /*atol=*/1e-3);
11175 });
11176 }
11177}
11178
11179TEST_F(LazyOpsTest, TestBCEWithLogitsBackward) {
11180 int batch = 10;
11181 int classes = 5;
11182 torch::Tensor undef;
11183 for (torch::Reduction::Reduction reduction :
11184 {torch::Reduction::None,
11185 torch::Reduction::Mean,
11186 torch::Reduction::Sum}) {
11187 auto testfn =
11188 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11189 return torch::binary_cross_entropy_with_logits(
11190 /*input=*/inputs[0],
11191 /*target=*/inputs[1],
11192 /*weight=*/inputs[2],
11193 /*pos_weight=*/inputs[3],
11194 /*reduction=*/reduction);
11195 };
11196 for (bool undef_weight : {false, true}) {
11197 for (bool undef_pos_weight : {false, true}) {
11198 torch::Tensor input = torch::rand(
11199 {batch, classes},
11200 torch::TensorOptions(torch::kFloat)
11201 .device(DefaultDevice())
11202 .requires_grad(true));
11203 torch::Tensor target = torch::rand(
11204 {batch, classes},
11205 torch::TensorOptions(torch::kFloat)
11206 .device(DefaultDevice())
11207 .requires_grad(true));
11208 torch::Tensor weight = undef_weight
11209 ? undef
11210 : torch::rand(
11211 {classes},
11212 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11213 torch::Tensor pos_weight = undef_pos_weight
11214 ? undef
11215 : torch::rand(
11216 {classes},
11217 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11218 ForEachDevice([&](const torch::Device& device) {
11219 TestBackward(
11220 {input, target, weight, pos_weight},
11221 device,
11222 testfn,
11223 /*rtol=*/1e-3,
11224 /*atol=*/1e-5);
11225 });
11226 }
11227 }
11228 }
11229}
11230
11231TEST_F(LazyOpsTest, TestKlDivBackward) {
11232 torch::Tensor input = torch::rand(
11233 {4, 3},
11234 torch::TensorOptions(torch::kFloat)
11235 .device(DefaultDevice())
11236 .requires_grad(true));
11237 torch::Tensor target = torch::rand(
11238 {4, 3},
11239 torch::TensorOptions(torch::kFloat)
11240 .device(DefaultDevice())
11241 .requires_grad(true));
11242 for (torch::Reduction::Reduction reduction :
11243 {torch::Reduction::Mean,
11244 torch::Reduction::Sum,
11245 torch::Reduction::None}) {
11246 auto testfn =
11247 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11248 return torch::kl_div(/*self=*/inputs[0], /*target=*/inputs[1], reduction);
11249 };
11250 ForEachDevice([&](const torch::Device& device) {
11251 TestBackward(
11252 {input, target},
11253 device,
11254 testfn,
11255 /*rtol=*/1e-4,
11256 /*atol=*/1e-5);
11257 });
11258 }
11259}
11260
11261TEST_F(LazyOpsTest, TestEmbeddingBackward) {
11262 int num_weights = 32;
11263 for (int padding_idx = -1; padding_idx < num_weights; ++padding_idx) {
11264 for (bool scale_grad_by_freq : {false, true}) {
11265 auto testfn =
11266 [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
11267 return torch::embedding(
11268 inputs[0],
11269 inputs[1],
11270 /*padding_idx=*/padding_idx,
11271 /*scale_grad_by_freq=*/scale_grad_by_freq,
11272 /*sparse=*/false);
11273 };
11274 ForEachDevice([&](const torch::Device& device) {
11275 torch::Tensor weight = torch::rand(
11276 {num_weights, 7},
11277 torch::TensorOptions(torch::kFloat)
11278 .device(DefaultDevice())
11279 .requires_grad(true));
11280 torch::Tensor indices = torch::randint(
11281 num_weights,
11282 {3, 9, 4},
11283 torch::TensorOptions(torch::kLong).device(DefaultDevice()));
11284 TestBackward(
11285 {weight, indices},
11286 device,
11287 testfn,
11288 /*rtol=*/1e-5,
11289 /*atol=*/1e-8);
11290 });
11291 }
11292 }
11293}
11294
11295TEST_F(LazyOpsTest, TestAmpForeachNonFiniteCheckAndUnscale) {
11296 if (IsCuda()) {
11297 // TODO(whc) debug failure on cuda
11298 GTEST_SKIP();
11299 }
11300
11301 torch::Tensor grads0 = torch::tensor(
11302 {1, 2, 3, 4},
11303 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11304 torch::Tensor grads1 = torch::tensor(
11305 {1.0, 2.0, std::nan("1"), 4.0},
11306 torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11307 torch::Tensor inv_scale = torch::scalar_tensor(
11308 0.2, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11309 torch::Tensor found_inf = torch::scalar_tensor(
11310 0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11311 torch::Tensor grads_output0 = grads0 * inv_scale;
11312 torch::Tensor found_inf_output0 = torch::scalar_tensor(
11313 0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11314 torch::Tensor found_inf_output1 = torch::scalar_tensor(
11315 1, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11316 ForEachDevice([&](const torch::Device& device) {
11317 if (grads0.device() == at::kCPU) {
11318 GTEST_SKIP();
11319 }
11320 torch::Tensor lazy_grads0 = CopyToDevice(grads0, device);
11321 torch::Tensor lazy_inv_scale = CopyToDevice(inv_scale, device);
11322 torch::Tensor lazy_found_inf = CopyToDevice(found_inf, device);
11323 torch::_amp_foreach_non_finite_check_and_unscale_(
11324 lazy_grads0, lazy_found_inf, lazy_inv_scale);
11325 AllClose(grads_output0, lazy_grads0, /*rtol=*/1e-2, /*atol=*/1e-4);
11326 AllEqual(found_inf_output0, lazy_found_inf);
11327
11328 torch::Tensor lazy_grads1 = CopyToDevice(grads1, device);
11329 torch::_amp_foreach_non_finite_check_and_unscale_(
11330 lazy_grads1, lazy_found_inf, lazy_inv_scale);
11331 AllEqual(found_inf_output1, lazy_found_inf);
11332 });
11333}
11334
11335TEST_F(LazyOpsTest, TestAmpUpdateScale) {
11336 torch::Tensor growth_tracker = torch::scalar_tensor(
11337 0, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11338 torch::Tensor current_scale = torch::scalar_tensor(
11339 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11340 torch::Tensor found_inf = torch::scalar_tensor(
11341 1, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11342 torch::Tensor not_found_inf = torch::scalar_tensor(
11343 0, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11344 float scale_growth_factor = 2.0;
11345 float scale_backoff_factor = 0.5;
11346 int growth_interval = 3;
11347
11348 torch::Tensor growth_tracker_result0 = torch::scalar_tensor(
11349 1, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11350 torch::Tensor current_scale_result0 = torch::scalar_tensor(
11351 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11352 torch::Tensor growth_tracker_result1 = torch::scalar_tensor(
11353 2, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11354 torch::Tensor current_scale_result1 = torch::scalar_tensor(
11355 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11356 torch::Tensor growth_tracker_result2 = torch::scalar_tensor(
11357 0, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11358 torch::Tensor current_scale_result2 = torch::scalar_tensor(
11359 8, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11360 torch::Tensor growth_tracker_result3 = torch::scalar_tensor(
11361 0, torch::TensorOptions(torch::kInt32).device(DefaultDevice()));
11362 torch::Tensor current_scale_result3 = torch::scalar_tensor(
11363 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11364
11365 ForEachDevice([&](const torch::Device& device) {
11366 if (growth_tracker.device() == at::kCPU) {
11367 GTEST_SKIP();
11368 }
11369 torch::Tensor lazy_growth_tracker = CopyToDevice(growth_tracker, device);
11370 torch::Tensor lazy_current_scale = CopyToDevice(current_scale, device);
11371 torch::Tensor lazy_found_inf = CopyToDevice(found_inf, device);
11372 torch::Tensor lazy_not_found_inf = CopyToDevice(not_found_inf, device);
11373
11374 torch::_amp_update_scale_(
11375 lazy_current_scale,
11376 lazy_growth_tracker,
11377 lazy_not_found_inf,
11378 scale_growth_factor,
11379 scale_backoff_factor,
11380 growth_interval);
11381 AllClose(
11382 current_scale_result0,
11383 lazy_current_scale,
11384 /*rtol=*/1e-2,
11385 /*atol=*/1e-4);
11386 AllEqual(growth_tracker_result0, lazy_growth_tracker);
11387
11388 torch::_amp_update_scale_(
11389 lazy_current_scale,
11390 lazy_growth_tracker,
11391 lazy_not_found_inf,
11392 scale_growth_factor,
11393 scale_backoff_factor,
11394 growth_interval);
11395 AllClose(
11396 current_scale_result1,
11397 lazy_current_scale,
11398 /*rtol=*/1e-2,
11399 /*atol=*/1e-4);
11400 AllEqual(growth_tracker_result1, lazy_growth_tracker);
11401
11402 // torch::_amp_update_scale_ returns the reference of current_scale
11403 lazy_current_scale = torch::_amp_update_scale_(
11404 lazy_current_scale,
11405 lazy_growth_tracker,
11406 lazy_not_found_inf,
11407 scale_growth_factor,
11408 scale_backoff_factor,
11409 growth_interval);
11410 AllClose(
11411 current_scale_result2,
11412 lazy_current_scale,
11413 /*rtol=*/1e-2,
11414 /*atol=*/1e-4);
11415 AllEqual(growth_tracker_result2, lazy_growth_tracker);
11416
11417 lazy_current_scale = torch::_amp_update_scale_(
11418 lazy_current_scale,
11419 lazy_growth_tracker,
11420 lazy_found_inf,
11421 scale_growth_factor,
11422 scale_backoff_factor,
11423 growth_interval);
11424 AllClose(
11425 current_scale_result3,
11426 lazy_current_scale,
11427 /*rtol=*/1e-2,
11428 /*atol=*/1e-4);
11429 AllEqual(growth_tracker_result3, lazy_growth_tracker);
11430 });
11431 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11432 ExpectCounterChanged("lazy::_amp_update_scale_", GetIgnoredCounters());
11433}
11434
11435TEST_F(LazyOpsTest, TestEarlySyncLiveTensors) {
11436 torch::Tensor scalar_tensor = torch::scalar_tensor(
11437 1., torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11438 torch::Scalar scalar1 = scalar_tensor.item();
11439 ForEachDevice([&](const torch::Device& device) {
11440 torch::Tensor lazy_scalar_tensor = CopyToDevice(scalar_tensor, device);
11441 torch::Scalar scalar2 = lazy_scalar_tensor.item();
11442 ASSERT_EQ(scalar1.to<float>(), scalar2.to<float>());
11443 });
11444 if (DebugUtil::ExperimentEnabled("early_sync")) {
11445 ExpectCounterChanged("EarlySyncLiveTensorsCount", GetIgnoredCounters());
11446 } else {
11447 ExpectCounterNotChanged("EarlySyncLiveTensorsCount", GetIgnoredCounters());
11448 }
11449 ExpectCounterChanged("aten::_local_scalar_dense", GetIgnoredCounters());
11450}
11451
11452TEST_F(LazyOpsTest, TestLerp) {
11453 torch::Tensor start = torch::rand(
11454 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11455 torch::Tensor end = torch::rand(
11456 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11457 torch::Tensor weight = torch::rand(
11458 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11459 torch::Tensor res = torch::lerp(start, end, weight);
11460 ForEachDevice([&](const torch::Device& device) {
11461 torch::Tensor lazy_start = CopyToDevice(start, device);
11462 torch::Tensor lazy_end = CopyToDevice(end, device);
11463 torch::Tensor lazy_weight = CopyToDevice(weight, device);
11464 torch::Tensor lazy_res = torch::lerp(lazy_start, lazy_end, lazy_weight);
11465 AllClose(res, lazy_res);
11466 });
11467 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11468 ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
11469}
11470
11471TEST_F(LazyOpsTest, TestLerpScalar) {
11472 torch::Tensor start = torch::rand(
11473 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11474 torch::Tensor end = torch::rand(
11475 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11476 torch::Scalar weight = torch::Scalar(3.0);
11477 torch::Tensor res = torch::lerp(start, end, weight);
11478 ForEachDevice([&](const torch::Device& device) {
11479 torch::Tensor lazy_start = CopyToDevice(start, device);
11480 torch::Tensor lazy_end = CopyToDevice(end, device);
11481 torch::Tensor lazy_res = torch::lerp(lazy_start, lazy_end, weight);
11482 AllClose(res, lazy_res);
11483 });
11484 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11485 ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
11486}
11487
11488TEST_F(LazyOpsTest, TestLerpInplace) {
11489 torch::Tensor input = torch::rand(
11490 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11491 torch::Tensor end = torch::rand(
11492 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11493 torch::Tensor weight = torch::rand(
11494 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11495 torch::Tensor input_copy = input.clone();
11496 input.lerp_(end, weight);
11497 ForEachDevice([&](const torch::Device& device) {
11498 torch::Tensor lazy_input = CopyToDevice(input_copy, device);
11499 torch::Tensor lazy_end = CopyToDevice(end, device);
11500 torch::Tensor lazy_weight = CopyToDevice(weight, device);
11501 lazy_input.lerp_(lazy_end, lazy_weight);
11502 AllClose(lazy_input, input);
11503 });
11504 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11505 ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
11506}
11507
11508TEST_F(LazyOpsTest, TestLerpScalarInplace) {
11509 torch::Tensor input = torch::rand(
11510 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11511 torch::Tensor end = torch::rand(
11512 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11513 torch::Scalar weight = torch::Scalar(3.0);
11514 torch::Tensor input_copy = input.clone();
11515 input.lerp_(end, weight);
11516 ForEachDevice([&](const torch::Device& device) {
11517 torch::Tensor lazy_input = CopyToDevice(input_copy, device);
11518 torch::Tensor lazy_end = CopyToDevice(end, device);
11519 lazy_input.lerp_(lazy_end, weight);
11520 AllClose(lazy_input, input);
11521 });
11522 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11523 ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
11524}
11525
11526TEST_F(LazyOpsTest, TestLerpOut) {
11527 torch::Tensor start = torch::rand(
11528 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11529 torch::Tensor end = torch::rand(
11530 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11531 torch::Tensor weight = torch::rand(
11532 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11533 torch::Tensor res = torch::empty(
11534 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11535 ;
11536 torch::lerp_out(res, start, end, weight);
11537 ForEachDevice([&](const torch::Device& device) {
11538 torch::Tensor lazy_start = CopyToDevice(start, device);
11539 torch::Tensor lazy_end = CopyToDevice(end, device);
11540 torch::Tensor lazy_weight = CopyToDevice(weight, device);
11541 torch::Tensor lazy_res = torch::empty({3, 4}, lazy_start.options());
11542 torch::lerp_out(lazy_res, lazy_start, lazy_end, lazy_weight);
11543 AllClose(res, lazy_res);
11544 });
11545 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11546 ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
11547}
11548
11549TEST_F(LazyOpsTest, TestLerpScalarOut) {
11550 torch::Tensor start = torch::rand(
11551 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11552 torch::Tensor end = torch::rand(
11553 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11554 torch::Scalar weight = torch::Scalar(3.0);
11555 torch::Tensor res = torch::empty(
11556 {3, 4}, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11557 torch::lerp_out(res, start, end, weight);
11558 ForEachDevice([&](const torch::Device& device) {
11559 torch::Tensor lazy_start = CopyToDevice(start, device);
11560 torch::Tensor lazy_end = CopyToDevice(end, device);
11561 torch::Tensor lazy_res = torch::empty({3, 4}, lazy_start.options());
11562 torch::lerp_out(lazy_res, lazy_start, lazy_end, weight);
11563 AllClose(res, lazy_res);
11564 });
11565 ExpectCounterNotChanged("aten::.*", GetIgnoredCounters());
11566 ExpectCounterChanged("lazy::lerp", GetIgnoredCounters());
11567}
11568
11569TEST_F(LazyOpsTest, IsAliasOf) {
11570 auto a = torch::empty(
11571 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11572 auto b = torch::empty(
11573 4, torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
11574
11575 ForEachDevice([&](const torch::Device& device) {
11576 auto lazy_a = CopyToDevice(a, device);
11577 auto lazy_b = CopyToDevice(b, device);
11578 EXPECT_EQ(!a.is_alias_of(b), !lazy_a.is_alias_of(lazy_b));
11579
11580 auto c = a.view({2, 2});
11581 auto lazy_c = lazy_a.view({2, 2});
11582 EXPECT_EQ(a.is_alias_of(c), lazy_a.is_alias_of(lazy_c));
11583
11584 auto d = c.view({1, 4});
11585 auto lazy_d = lazy_c.view({1, 4});
11586 EXPECT_EQ(d.is_alias_of(c), lazy_d.is_alias_of(lazy_c));
11587 EXPECT_EQ(d.is_alias_of(a), lazy_d.is_alias_of(lazy_a));
11588 });
11589}
11590
11591#endif // FBCODE_CAFFE2
11592
11593} // namespace lazy
11594} // namespace torch
11595