1 | #include <ATen/Operators.h> |
2 | #include <ATen/native/CPUFallback.h> |
3 | #include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h> |
4 | #include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h> |
5 | |
6 | namespace torch { |
7 | namespace lazy { |
8 | |
9 | at::Tensor MaxPool3dAutogradFunctionTS::forward( |
10 | torch::autograd::AutogradContext* ctx, |
11 | at::Tensor self, |
12 | at::IntArrayRef kernel_size, |
13 | at::IntArrayRef stride, |
14 | at::IntArrayRef padding, |
15 | at::IntArrayRef dilation, |
16 | bool ceil_mode) { |
17 | ctx->saved_data["kernel_size" ] = kernel_size; |
18 | ctx->saved_data["stride" ] = stride; |
19 | ctx->saved_data["padding" ] = padding; |
20 | ctx->saved_data["dilation" ] = dilation; |
21 | ctx->saved_data["ceil_mode" ] = ceil_mode; |
22 | auto results = at::native:: |
23 | call_fallback_fn<<c_eager_fallback, ATEN_OP(max_pool3d_with_indices)>:: |
24 | call(self, kernel_size, stride, padding, dilation, ceil_mode); |
25 | ctx->save_for_backward({self, std::get<1>(results)}); |
26 | return std::get<0>(results); |
27 | } |
28 | |
29 | torch::autograd::variable_list MaxPool3dAutogradFunctionTS::backward( |
30 | torch::autograd::AutogradContext* ctx, |
31 | torch::autograd::variable_list grad_output) { |
32 | auto kernel_size = ctx->saved_data["kernel_size" ].toIntList().vec(); |
33 | auto stride = ctx->saved_data["stride" ].toIntList().vec(); |
34 | auto padding = ctx->saved_data["padding" ].toIntList().vec(); |
35 | auto dilation = ctx->saved_data["dilation" ].toIntList().vec(); |
36 | auto ceil_mode = ctx->saved_data["ceil_mode" ].toBool(); |
37 | auto saved = ctx->get_saved_variables(); |
38 | auto self = saved[0]; |
39 | at::Tensor grad; |
40 | auto indices = saved[1]; |
41 | grad = at::native::call_fallback_fn< |
42 | <c_eager_fallback, |
43 | ATEN_OP(max_pool3d_with_indices_backward)>:: |
44 | call( |
45 | grad_output[0], |
46 | self, |
47 | kernel_size, |
48 | stride, |
49 | padding, |
50 | dilation, |
51 | ceil_mode, |
52 | indices); |
53 | |
54 | at::Tensor undef; |
55 | torch::autograd::variable_list grad_inputs = { |
56 | grad, undef, undef, undef, undef, undef}; |
57 | return grad_inputs; |
58 | } |
59 | |
60 | } // namespace lazy |
61 | } // namespace torch |
62 | |