1 | #pragma once |
2 | |
3 | #include <ATen/core/Tensor.h> |
4 | #include <c10/core/Scalar.h> |
5 | |
6 | #ifndef AT_PER_OPERATOR_HEADERS |
7 | #include <ATen/Functions.h> |
8 | #else |
9 | #include <ATen/ops/empty_like.h> |
10 | #endif |
11 | |
12 | #include <stdexcept> |
13 | #include <string> |
14 | |
15 | namespace at { |
16 | |
17 | #define AT_FORALL_BINARY_OPS(_) \ |
18 | _(+, x.add(y), y.add(x)) \ |
19 | _(*, x.mul(y), y.mul(x)) \ |
20 | _(-, \ |
21 | x.sub(y), \ |
22 | ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \ |
23 | _(/, \ |
24 | x.div(y), \ |
25 | ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \ |
26 | _(%, \ |
27 | x.remainder(y), \ |
28 | ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \ |
29 | _(&, x.bitwise_and(y), y.bitwise_and(x)) \ |
30 | _(|, x.bitwise_or(y), y.bitwise_or(x)) \ |
31 | _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \ |
32 | _(<, x.lt(y), y.gt(x)) \ |
33 | _(<=, x.le(y), y.ge(x)) \ |
34 | _(>, x.gt(y), y.lt(x)) \ |
35 | _(>=, x.ge(y), y.le(x)) \ |
36 | _(==, x.eq(y), y.eq(x)) \ |
37 | _(!=, x.ne(y), y.ne(x)) |
38 | |
39 | #define DEFINE_OPERATOR(op, body, reverse_scalar_body) \ |
40 | static inline Tensor operator op(const Tensor& x, const Tensor& y) { \ |
41 | return body; \ |
42 | } \ |
43 | static inline Tensor operator op(const Tensor& x, const Scalar& y) { \ |
44 | return body; \ |
45 | } \ |
46 | static inline Tensor operator op(const Scalar& x, const Tensor& y) { \ |
47 | return reverse_scalar_body; \ |
48 | } |
49 | |
50 | AT_FORALL_BINARY_OPS(DEFINE_OPERATOR) |
51 | #undef DEFINE_OPERATOR |
52 | #undef AT_FORALL_BINARY_OPS |
53 | |
54 | } // namespace at |
55 | |