1#pragma once
2
#include <ir_interface_nodes.h>
#include <maxinfo_propagator.h>
#include <transform_replay.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_set>
#include <vector>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15class MaxPosCalculator {
16 // Root domains in producer that's unmappable to any of its consumers
17 std::unordered_set<IterDomain*> unmappable_dims_;
18
19 // User set IterDomains to not inline, used in schedulers to avoid inlining
20 // trivial reductions
21 std::unordered_set<IterDomain*> uninlinable_ids_;
22
23 // Iterate through all TVs and collect the dimensions of each TV that don't
24 // map to all its consumer TVs.
25 void buildUnmappableDims();
26
27 // Utility function to return if an id of tv is a valid iter domain to inline
28 // within. This is used in getMaxPos{PasC,CasP}. Different variations of the
29 // bool values are used if checking max position of PasC, CasP, or checking
30 // for a max "self" position.
31 bool isAllowedID(
32 IterDomain* id,
33 TensorView* tv,
34 bool best_effort,
35 bool allow_reduction,
36 bool allow_vectorize,
37 bool allow_unmappable) const;
38
39 public:
40 // Returns the position at which tv can be inlined within.
41 size_t getMaxPosSelf(
42 TensorView* tv,
43 bool best_effort,
44 bool allow_reduction,
45 bool allow_vectorize,
46 bool allow_unmappable) const;
47
48 // Returns the maximum position producer can be inlined based on consumer
49 // given the set ComputeAtMode
50 size_t getMaxProducerPosFromConsumer(
51 TensorView* producer,
52 TensorView* consumer,
53 bool best_effort) const;
54
55 // Checks producers, consumers, and siblings to see what the maximum position
56 // in tv is that can be shared across both directions.
57 size_t getMaxPosAll(
58 TensorView* tv,
59 bool best_effort = false,
60 bool check_siblings = true);
61
62 MaxPosCalculator(const std::unordered_set<IterDomain*>& uninlinable_ids = {});
63};
64
65// Inline to the right most allowed position for all tensors in the current
66// fusion.
67TORCH_CUDA_CU_API void inlineMost(
68 const std::unordered_set<IterDomain*>& uninlinable_ids = {});
69// Inline to the right most allowed position for the selected tensors in the
70// current fusion.
71TORCH_CUDA_CU_API void inlineMost(
72 const std::vector<TensorView*>& tvs,
73 const std::unordered_set<IterDomain*>& uninlinable_ids = {});
74// Inline to the right most allowed position for the selected tensors in the
75// current fusion.
76TORCH_CUDA_CU_API void inlineMost(
77 const std::unordered_set<TensorView*>& tvs,
78 const std::unordered_set<IterDomain*>& uninlinable_ids = {});
79
80// Inline to the position corresponding to the reference position in the
81// reference tensor for all tensors in the current fusion.
82TORCH_CUDA_CU_API void inlineAllAt(
83 TensorView* reference_tv,
84 int64_t reference_pos,
85 bool best_effort = false,
86 const std::unordered_set<IterDomain*>& uninlinable_ids = {});
87
88// Inline to the position corresponding to the reference position in the
89// reference tensor for selected tensors in the current fusion.
90TORCH_CUDA_CU_API void inlineSelectedAt(
91 const std::unordered_set<TensorView*>& selected,
92 TensorView* reference_tv,
93 int64_t reference_pos,
94 bool best_effort = false,
95 const std::unordered_set<IterDomain*>& uninlinable_ids = {});
96
97} // namespace cuda
98} // namespace fuser
99} // namespace jit
100} // namespace torch
101