1#pragma once
2
3#include <c10/macros/Export.h>
4
5#include <ir_all_nodes.h>
6#include <transform_iter.h>
7
8#include <algorithm>
9#include <vector>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
// TODO: Only the replay dispatch is really borrowed from TransformIter; we
// should reevaluate the reuse of dispatch for classes that inherit
// TransformIter.
18class TORCH_CUDA_CU_API TransformRFactor {
19 public:
20 // Transform the provided tensor domain to two domains, a producer and
21 // consumer domain. These domains are created by taking axes and reducing them
22 // in the producer domain, and taking the remaining reduction axes and
23 // reducing them in the consumer domain.
24 static std::pair<TensorDomain*, TensorDomain*> runReplay(
25 TensorDomain*,
26 std::vector<int> axes);
27};
28
29} // namespace cuda
30} // namespace fuser
31} // namespace jit
32} // namespace torch
33