#pragma once

#include <ir_all_nodes.h>
#include <parallel_type_bitmap.h>

#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
13class SyncMap {
14 public:
15 std::string toString() const;
16
17 //! Validates all tensors are consistently parallelized. Basically,
18 //! when a producer axis is threaded, either with threadIdx or
19 //! blockIdx, there must be a mapped consumer axis with the
20 //! same ParallelType with some exceptions.
21 //!
22 //! This function assumes Loop and Parallel ComputeAtMaps are already
23 //! built as they are used to validate consistency.
24 //!
25 //! Fills needs_raw_sync with output TVs if they need a raw sync if on smem or
26 //! gmem. The second entry in this map is the parallel dimensions being
27 //! communicated across.
28 void build(Fusion* fusion);
29
30 ParallelTypeBitmap needsRawSync(TensorView* tv) const {
31 auto it = needs_raw_sync_.find(tv);
32 if (it != needs_raw_sync_.end()) {
33 return it->second;
34 }
35 return ParallelTypeBitmap();
36 }
37
38 private:
39 std::unordered_map<TensorView*, ParallelTypeBitmap> needs_raw_sync_;
40};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch