1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <ir_all_nodes.h> |
6 | #include <kernel_ir.h> |
7 | |
8 | #include <vector> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | //! Inserts a sync at the end of for-loops to prevent write-after-read (WAR) race conditions. |
16 | //! |
17 | //! A WAR race condition occurs when the next iteration of a loop overwrites a shared memory |
18 | //! value before a previous operation has finished reading it. `exprs` is taken by const reference and a std::vector<Expr*> is returned by value (presumably the expression list with the syncs inserted; confirm against the definition). |
19 | std::vector<Expr*> insertWarThreadSynchronization( |
20 | const std::vector<Expr*>& exprs); |
21 | |
22 | //! Inserts syncs between a write to shared memory and a later read of it, i.e. read-after-write (RAW) ordering. |
23 | //! The RAW pass is run before indexing, unrolling (loop duplication), memory |
24 | //! aliasing, and index (grid/block bcast/reduction). NOTE(review): the last item reads as truncated; confirm which pass is meant. `exprs` is taken by const reference and a std::vector<Expr*> is returned by value (presumably the expression list with the syncs inserted; confirm against the definition). |
25 | std::vector<Expr*> insertRawThreadSynchronization( |
26 | const std::vector<Expr*>& exprs); |
27 | |
28 | } // namespace cuda |
29 | } // namespace fuser |
30 | } // namespace jit |
31 | } // namespace torch |
32 | |