1 | #pragma once |
2 | |
#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <unordered_map>
#include <vector>
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
//! Forward declaration only: within this header the type is used solely by
//! const reference (see protectPredicateIndexWithMagicZero), so its full
//! definition is not required here.
struct IndexFromIdGraph;
14 | |
15 | //! Insert magic zero definition at the begining of the kernel. Insert magic |
16 | //! zero update after every (outer most) loop nest with a compile time extent. |
17 | //! |
18 | //! This will make sure nvrtc does not aggressively save predicate and indices. |
19 | std::vector<Expr*> insertMagicZero(const std::vector<Expr*>& exprs); |
20 | |
21 | //! Check if val is a reference to the magic zero variable |
22 | TORCH_CUDA_CU_API bool isMagicZero(const Val* val); |
23 | |
24 | //! Check if val is protected with magic zero. |
25 | //! |
26 | //! Specifically, this returns true if val is defined as "x + magic_zero". |
27 | bool isProtectedWithMagicZero(const Val* val); |
28 | |
29 | // Determine if we may run into over reuse of predicates or registers in the |
30 | // compiler. If the loop can be unrolled and the index and domain are not |
31 | // "simple" we likely want the loop protected. |
32 | // |
33 | // Magic zero protection should only be done for global memory and predicates. |
34 | // We should avoid use on registers. Shared memory does not require it, but |
35 | // likely wouldn't hurt. |
36 | bool needsMagicZero( |
37 | kir::ForLoop* loop, |
38 | IterDomain* reference_domain = nullptr, |
39 | Val* ind = nullptr); |
40 | |
41 | struct IndexMagicZeroInfo { |
42 | //! Index that may be updated with magic zero |
43 | Val* index = nullptr; |
44 | //! Loop index that is protected by magic zero. nullptr if no loop |
45 | //! is protected |
46 | Val* original_loop_index = nullptr; |
47 | //! Protected loop index. nullptr if no loop is protected |
48 | Val* protected_loop_index = nullptr; |
49 | //! Protected loop. nullptr if no loop is protected |
50 | IterDomain* loop_id = nullptr; |
51 | }; |
52 | |
53 | //! Protect an index val of an IterDomain with magic zero |
54 | //! |
55 | //! This should be only used for predicate indexing. |
56 | //! |
57 | //! No protection is done if none of the loops is determined to require |
58 | //! protection by needsMagicZero. |
59 | IndexMagicZeroInfo protectPredicateIndexWithMagicZero( |
60 | Val* index, |
61 | const IndexFromIdGraph& id_graph, |
62 | const std::vector<kir::ForLoop*>& loops); |
63 | |
64 | //! Protect an index val of a tensor with magic zero |
65 | //! |
66 | //! This should be only used for non-predicate indexing. |
67 | //! |
68 | //! No protection is done if none of the loops is determined to require |
69 | //! protection by needsMagicZero. |
70 | void protectNonPredicateIndexWithMagicZero( |
71 | const std::vector<kir::ForLoop*>& loops, |
72 | const std::vector<IterDomain*>& loop_domains, |
73 | std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map); |
74 | |
75 | } // namespace cuda |
76 | } // namespace fuser |
77 | } // namespace jit |
78 | } // namespace torch |
79 | |