1#include <ATen/Operators.h>
2#include <ATen/native/CPUFallback.h>
3#include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h>
4#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
5
6namespace torch {
7namespace lazy {
8
9at::Tensor MaxPool3dAutogradFunctionTS::forward(
10 torch::autograd::AutogradContext* ctx,
11 at::Tensor self,
12 at::IntArrayRef kernel_size,
13 at::IntArrayRef stride,
14 at::IntArrayRef padding,
15 at::IntArrayRef dilation,
16 bool ceil_mode) {
17 ctx->saved_data["kernel_size"] = kernel_size;
18 ctx->saved_data["stride"] = stride;
19 ctx->saved_data["padding"] = padding;
20 ctx->saved_data["dilation"] = dilation;
21 ctx->saved_data["ceil_mode"] = ceil_mode;
22 auto results = at::native::
23 call_fallback_fn<&ltc_eager_fallback, ATEN_OP(max_pool3d_with_indices)>::
24 call(self, kernel_size, stride, padding, dilation, ceil_mode);
25 ctx->save_for_backward({self, std::get<1>(results)});
26 return std::get<0>(results);
27}
28
29torch::autograd::variable_list MaxPool3dAutogradFunctionTS::backward(
30 torch::autograd::AutogradContext* ctx,
31 torch::autograd::variable_list grad_output) {
32 auto kernel_size = ctx->saved_data["kernel_size"].toIntList().vec();
33 auto stride = ctx->saved_data["stride"].toIntList().vec();
34 auto padding = ctx->saved_data["padding"].toIntList().vec();
35 auto dilation = ctx->saved_data["dilation"].toIntList().vec();
36 auto ceil_mode = ctx->saved_data["ceil_mode"].toBool();
37 auto saved = ctx->get_saved_variables();
38 auto self = saved[0];
39 at::Tensor grad;
40 auto indices = saved[1];
41 grad = at::native::call_fallback_fn<
42 &ltc_eager_fallback,
43 ATEN_OP(max_pool3d_with_indices_backward)>::
44 call(
45 grad_output[0],
46 self,
47 kernel_size,
48 stride,
49 padding,
50 dilation,
51 ceil_mode,
52 indices);
53
54 at::Tensor undef;
55 torch::autograd::variable_list grad_inputs = {
56 grad, undef, undef, undef, undef, undef};
57 return grad_inputs;
58}
59
60} // namespace lazy
61} // namespace torch
62