1#pragma once
2
3#include <c10/macros/Export.h>
4#include <torch/csrc/jit/ir/ir.h>
5#include <torch/csrc/jit/runtime/profiling_record.h>
6
7#include <fusion.h>
8
9/*
10 * This file handles Parsing PyTorch jit ir;
11 *
12 * It is used in two places:
13 * 1. When partitioning PyTorch jit ir to create prim::CudaFusionGroup, each
14 * node is queried by `isNodeParsible` to determine whether the node could
15 * be handled by the fuser (whether a given PyTorch jit operator should be
16 * merged);
17 * 2. lowering PyTorch jit ir to CUDA codegen ir.
18 * creates a `Fusion` by traversing a PyTorch jit graph.
19 *
20 * TODO: we could consider exposing API to allow custom registration of parsing
21 * rules for a given PyTorch jit operator.
22 */
23
24namespace torch {
25namespace jit {
26namespace fuser {
27namespace cuda {
28
// Launch-configuration constants (presumably CUDA thread-block dimensions
// consumed by the fuser's scheduling heuristics — TODO confirm against the
// schedulers that read them; they are not visible in this header).
//
// Declared `inline constexpr` (C++17) so every translation unit that includes
// this header shares a single definition, rather than each TU owning its own
// internal-linkage copy of the constant.
//
// Pointwise kernels: threads along x.
inline constexpr int kPwThreadX = 128;
// Reduction over the fastest-changing dimension (FCD): threads along x.
inline constexpr int kFcdReductionThreadX = 128;
// Reduction over a non-FCD axis: 32x32 thread block (x and y).
inline constexpr int kNonFcdReductionThreadX = 32;
inline constexpr int kNonFcdReductionThreadY = 32;
33
// Reduction queries (implementations not in this header).
// Presumably `hasReductionNode` scans all nodes in `block` while the `is*`
// variants classify a single node — TODO confirm against parser.cpp.
TORCH_CUDA_CU_API bool hasReductionNode(const Block* block);
TORCH_CUDA_CU_API bool isReductionToSizeNode(const Node* node);
TORCH_CUDA_CU_API bool isReductionNode(const Node* node);

// Normalization queries, parallel in shape to the reduction queries above.
TORCH_CUDA_CU_API bool hasNormalizationNode(const Block* block);
TORCH_CUDA_CU_API bool isNormalizationNode(const Node* node);

// Single-node element-wise classification.
TORCH_CUDA_CU_API bool isElementWiseNode(const Node* node);

// returns whether or not a parsing function exists for the given node type.
// (Used during partitioning — see item 1 of the file comment.)
TORCH_CUDA_CU_API bool isNodeParsible(const Node* node);
// Whether the profiling executor should insert profiling for this node.
TORCH_CUDA_CU_API bool shouldProfileNode(const Node* node);

// Queries/updates the set of node kinds the fuser skips, keyed by the
// symbol's string form; the role of `flip` is not visible here —
// NOTE(review): looks like it toggles membership, confirm in parser.cpp.
TORCH_CUDA_CU_API bool skipNodeKind(const std::string& symbol_str, bool flip);

// Hook for the profiling record to insert the profile nodes the fuser needs.
void InsertProfileNodes(ProfilingRecord* pr);

// lowers PyTorch jit graph to `Fusion`.
// Caller takes ownership of the returned `Fusion` via unique_ptr.
TORCH_CUDA_CU_API std::unique_ptr<Fusion> parseJitIR(
    const std::shared_ptr<Graph>& graph);
54
55} // namespace cuda
56} // namespace fuser
57} // namespace jit
58} // namespace torch
59