1#pragma once
2
#include <c10/macros/Export.h>

#include <ir_all_nodes.h>
#include <iter_visitor.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13//! If an IterDomain is split and its inner output domain is
14//! eventually split too, the second split must be divisible or the
15//! inner domain must be predicated. This class finds Split
16//! expressions that need to be divisible or predicated.
17//!
18//! Second splits are not limited to just direct output domains of
19//! first splits but also indirect descendent domains as well.
20//!
21//! Predicating non-divisible split domains does not work if split
22//! output domains are vectorized where ParallelType::Vectorize is
23//! applied to an inner domain of splits. If it's non-divisible,
24//! predicating the input domain of the non-divisible split results in
25//! a vectoried operation is predicated out entirely since we do not
26//! generate a fall-back non-vectorized else path. Runtime check is
27//! done for those domains.
28class TORCH_CUDA_CU_API NonDivisibleSplitInfo : public IterVisitor {
29 public:
30 void build(Fusion* fusion);
31
32 const auto& splitsToPredicate() const {
33 return splits_to_predicate_;
34 }
35
36 const auto& splitsToValidate() const {
37 return splits_to_validate_;
38 }
39
40 private:
41 using IterVisitor::handle;
42
43 void handle(Split* split) override;
44
45 void handle(Merge* merge) override;
46
47 //! True if reachable from inner domains of splits
48 bool isReachableFromInnerDomains(IterDomain* id) const;
49
50 //! Forward propagate the reachability information
51 void propagateReachability(Split* split, bool is_protected);
52
53 //! Forward propagate the reachability information
54 void propagateReachability(Merge* merge);
55
56 void clearReachability();
57
58 //! Returns the extent of a split output domain if it's not proven to
59 //! be divisible.
60 Val* getMaybeNonDivisibleExtent(Split* split) const;
61
62 //! Remove redundant predicates as divisibility may be validated at
63 //! run time
64 void removeRedundancy();
65
66 private:
67 //! Split expressions whose input domain must be predicated
68 std::unordered_map<TensorView*, std::vector<Split*>> splits_to_predicate_;
69 //! Split expressions whose divisibility must be validated at run time
70 std::unordered_set<Split*> splits_to_validate_;
71
72 //! Temporarily used for analyzing each tensor
73 TensorView* current_tv_ = nullptr;
74 std::unordered_set<IterDomain*> inner_domains_;
75};
76
77} // namespace cuda
78} // namespace fuser
79} // namespace jit
80} // namespace torch
81