1#pragma once
2
3#include <ATen/core/Tensor.h>
4#include <c10/core/Scalar.h>
5
6#ifndef AT_PER_OPERATOR_HEADERS
7#include <ATen/Functions.h>
8#else
9#include <ATen/ops/empty_like.h>
10#endif
11
12#include <stdexcept>
13#include <string>
14
15namespace at {
16
17#define AT_FORALL_BINARY_OPS(_) \
18 _(+, x.add(y), y.add(x)) \
19 _(*, x.mul(y), y.mul(x)) \
20 _(-, \
21 x.sub(y), \
22 ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \
23 _(/, \
24 x.div(y), \
25 ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \
26 _(%, \
27 x.remainder(y), \
28 ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \
29 _(&, x.bitwise_and(y), y.bitwise_and(x)) \
30 _(|, x.bitwise_or(y), y.bitwise_or(x)) \
31 _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \
32 _(<, x.lt(y), y.gt(x)) \
33 _(<=, x.le(y), y.ge(x)) \
34 _(>, x.gt(y), y.lt(x)) \
35 _(>=, x.ge(y), y.le(x)) \
36 _(==, x.eq(y), y.eq(x)) \
37 _(!=, x.ne(y), y.ne(x))
38
39#define DEFINE_OPERATOR(op, body, reverse_scalar_body) \
40 static inline Tensor operator op(const Tensor& x, const Tensor& y) { \
41 return body; \
42 } \
43 static inline Tensor operator op(const Tensor& x, const Scalar& y) { \
44 return body; \
45 } \
46 static inline Tensor operator op(const Scalar& x, const Tensor& y) { \
47 return reverse_scalar_body; \
48 }
49
50AT_FORALL_BINARY_OPS(DEFINE_OPERATOR)
51#undef DEFINE_OPERATOR
52#undef AT_FORALL_BINARY_OPS
53
54} // namespace at
55