1 | #pragma once |
2 | #include <ATen/ATen.h> |
3 | #include <ATen/core/op_registration/op_registration.h> |
4 | #include <torch/library.h> |
5 | |
namespace at {

// If an operator doesn't have a batching rule implemented then we fallback
// to this implementation. The fallback only works on out-of-place operators
// that return only tensors with new memory. (e.g., no in-place operators, no
// view operations).
//
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
// them, and runs `op` on all of the corresponding slices to produce slices
// of the outputs. The output slices then get `torch.stack`ed to create the
// final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
//
// Boxed-fallback signature: `op` identifies the operator being dispatched and
// `stack` holds its boxed arguments on entry; the implementation is expected
// to pop the arguments and push the results back onto `stack`.
// NOTE(review): declaration only — presumably defined in the corresponding
// .cpp and registered as a boxed fallback for the Batched dispatch key;
// confirm against the implementation file.
void batchedTensorForLoopFallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

} // namespace at
26 | |