1 | #include <torch/csrc/lazy/ts_backend/tensor_aten_ops.h> |
2 | |
3 | #include <ATen/InferSize.h> |
4 | #include <c10/util/Optional.h> |
5 | #include <torch/csrc/autograd/variable.h> |
6 | #include <torch/csrc/lazy/core/helpers.h> |
7 | #include <torch/csrc/lazy/core/ir_builder.h> |
8 | #include <torch/csrc/lazy/core/ir_util.h> |
9 | #include <torch/csrc/lazy/core/lazy_graph_executor.h> |
10 | #include <torch/csrc/lazy/core/metrics.h> |
11 | #include <torch/csrc/lazy/core/ops/arithmetic_ir_ops.h> |
12 | #include <torch/csrc/lazy/core/ops/utils.h> |
13 | #include <torch/csrc/lazy/core/tensor.h> |
14 | #include <torch/csrc/lazy/core/util.h> |
15 | #include <torch/csrc/lazy/generated/LazyIr.h> |
16 | #include <torch/csrc/lazy/ts_backend/ops/random_ops.h> |
17 | #include <algorithm> |
18 | #include <functional> |
19 | |
20 | namespace torch { |
21 | namespace lazy { |
22 | namespace { |
23 | |
24 | // to enable operator+-*/ for Value |
25 | using namespace torch::lazy; |
26 | |
27 | torch::lazy::Value MaybeExpand( |
28 | const torch::lazy::Value& input, |
29 | const torch::lazy::Shape& target_shape) { |
30 | if (input.shape().sizes() == target_shape.sizes()) { |
31 | return input; |
32 | } |
33 | return torch::lazy::MakeExpand( |
34 | input, |
35 | target_shape.sizes().vec(), |
36 | /*is_scalar_expand=*/false); |
37 | } |
38 | |
39 | } // namespace |
40 | |
41 | ////////////////////////////////////////////////////////////////////////////// |
// ATen operators follow here.
43 | ////////////////////////////////////////////////////////////////////////////// |
44 | |
45 | void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) { |
46 | torch::lazy::Value constant = |
47 | torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( |
48 | value, input->shape(), input->GetDevice()); |
49 | input->SetInPlaceIrValue(std::move(constant)); |
50 | } |
51 | |
52 | void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { |
53 | if (input->GetDevice() == src->GetDevice()) { |
54 | torch::lazy::Value copy_value; |
55 | if (input->dtype() == src->dtype()) { |
56 | copy_value = src->GetIrValue(); |
57 | } else { |
58 | copy_value = torch::lazy::MakeCast( |
59 | src->GetIrValue(), input->dtype(), src->dtype()); |
60 | } |
61 | input->SetIrValue(MaybeExpand(copy_value, input->shape())); |
62 | } else { |
63 | auto input_shape = input->shape(); |
64 | at::Tensor src_tensor = src->ToTensor(/*detached=*/true); |
65 | if (src_tensor.sizes() != input_shape.Get().sizes()) { |
66 | src_tensor = src_tensor.expand(input_shape.Get().sizes().vec()); |
67 | } |
68 | input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false); |
69 | } |
70 | } |
71 | |
72 | } // namespace lazy |
73 | } // namespace torch |
74 | |