1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/runtime/profiling_record.h> |
6 | |
7 | #include <fusion.h> |
8 | |
/*
 * This file handles parsing of PyTorch JIT IR.
 *
 * It is used in two places:
 *   1. When partitioning PyTorch JIT IR to create prim::CudaFusionGroup, each
 *      node is queried by `isNodeParsible` to determine whether the fuser can
 *      handle it (i.e. whether a given PyTorch JIT operator should be merged);
 *   2. When lowering PyTorch JIT IR to CUDA codegen IR, `parseJitIR` creates a
 *      `Fusion` by traversing a PyTorch JIT graph.
 *
 * TODO: we could consider exposing an API to allow custom registration of
 * parsing rules for a given PyTorch JIT operator.
 */
23 | |
24 | namespace torch { |
25 | namespace jit { |
26 | namespace fuser { |
27 | namespace cuda { |
28 | |
29 | constexpr int kPwThreadX = 128; |
30 | constexpr int kFcdReductionThreadX = 128; |
31 | constexpr int kNonFcdReductionThreadX = 32; |
32 | constexpr int kNonFcdReductionThreadY = 32; |
33 | |
34 | TORCH_CUDA_CU_API bool hasReductionNode(const Block* block); |
35 | TORCH_CUDA_CU_API bool isReductionToSizeNode(const Node* node); |
36 | TORCH_CUDA_CU_API bool isReductionNode(const Node* node); |
37 | |
38 | TORCH_CUDA_CU_API bool hasNormalizationNode(const Block* block); |
39 | TORCH_CUDA_CU_API bool isNormalizationNode(const Node* node); |
40 | |
41 | TORCH_CUDA_CU_API bool isElementWiseNode(const Node* node); |
42 | |
43 | // returns whether or not a parsing function exists for the given node type. |
44 | TORCH_CUDA_CU_API bool isNodeParsible(const Node* node); |
45 | TORCH_CUDA_CU_API bool shouldProfileNode(const Node* node); |
46 | |
47 | TORCH_CUDA_CU_API bool skipNodeKind(const std::string& symbol_str, bool flip); |
48 | |
49 | void InsertProfileNodes(ProfilingRecord* pr); |
50 | |
51 | // lowers PyTorch jit graph to `Fusion`. |
52 | TORCH_CUDA_CU_API std::unique_ptr<Fusion> parseJitIR( |
53 | const std::shared_ptr<Graph>& graph); |
54 | |
55 | } // namespace cuda |
56 | } // namespace fuser |
57 | } // namespace jit |
58 | } // namespace torch |
59 | |