1#pragma once
2
3#include <c10/macros/Export.h>
4
5#include <dispatch.h>
6#include <ir_all_nodes.h>
7#include <kernel_ir.h>
8
9#include <unordered_set>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
16//! Detect almost all IterDomains that are derived from trivial
17//! reductons.
18class TORCH_CUDA_CU_API TrivialReductionInfo {
19 public:
20 void build(Fusion* fusion);
21
22 bool isDerived(IterDomain* id) const;
23
24 private:
25 //! IterDomains that are derived only from trivial
26 //! reductons. Included domains are not limited to reduction axes as
27 //! rfactor can make reductions to normal axes.
28 //!
29 //! Note that the set should cover almost all cases but there can be
30 //! undetected trivial domains. For example, split by one creates a
31 //! trivial reduction domain, which is detected. However, if it is
32 //! further split, both of the two resulting axes are also trivial,
33 //! however, only the inner axis is recognized as trivial. While this
34 //! is a limitation, it would have very little practical
35 //! implication.
36 std::unordered_set<IterDomain*> domains_;
37 //! Subset of domains_, whose input root axes are all derived from
38 //! trivial reductions. These domains do not need to manifest as
39 //! for-loops.
40 std::unordered_set<IterDomain*> domains_derived_from_root_;
41};
42
43} // namespace cuda
44} // namespace fuser
45} // namespace jit
46} // namespace torch
47