1 | #pragma once |
2 | |
3 | #include <ATen/core/jit_type.h> |
4 | #include <torch/csrc/jit/ir/alias_analysis.h> |
5 | #include <torch/csrc/jit/ir/ir_views.h> |
6 | #include <torch/csrc/jit/jit_log.h> |
7 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
8 | #include <torch/csrc/jit/passes/peephole.h> |
9 | #include <torch/csrc/jit/passes/peephole_list_idioms.h> |
10 | #include <torch/csrc/jit/runtime/graph_executor.h> |
11 | #include <torch/csrc/utils/memory.h> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | |
// Refinement of List-typed Values to their lengths: maps a Value* of type
// List to the length that list is known to have. If a mapping
// `list_value -> len` is present for a block, the list is guaranteed to have
// exactly `len` elements within that block.
// TODO: a vector may be faster than an unordered_map here.
using ListRefinement = std::unordered_map<Value*, int64_t>;
21 | |
22 | TORCH_API ListRefinement |
23 | intersectRefinements(const ListRefinement& ref1, const ListRefinement& ref2); |
24 | |
25 | TORCH_API ListRefinement |
26 | unionRefinements(const ListRefinement& ref1, const ListRefinement& ref2); |
27 | |
28 | // Represents the refinement information that can be carried on a boolean |
29 | struct BooleanRefinementMapping { |
30 | BooleanRefinementMapping( |
31 | ListRefinement true_refine, |
32 | ListRefinement false_refine) |
33 | : true_refine_(std::move(true_refine)), |
34 | false_refine_(std::move(false_refine)){}; |
35 | BooleanRefinementMapping() = default; // empty |
36 | |
37 | static BooleanRefinementMapping FalseRefinements( |
38 | ListRefinement false_refine) { |
39 | return BooleanRefinementMapping({}, std::move(false_refine)); |
40 | } |
41 | |
42 | static BooleanRefinementMapping TrueRefinements(ListRefinement true_refine) { |
43 | return BooleanRefinementMapping(std::move(true_refine), {}); |
44 | } |
45 | |
46 | BooleanRefinementMapping intersectBooleanRefinementMapping( |
47 | BooleanRefinementMapping& other) { |
48 | return BooleanRefinementMapping( |
49 | intersectRefinements(true_refine_, other.true_refine()), |
50 | intersectRefinements(false_refine_, other.false_refine())); |
51 | } |
52 | |
53 | ListRefinement& true_refine() { |
54 | return true_refine_; |
55 | } |
56 | |
57 | ListRefinement& false_refine() { |
58 | return false_refine_; |
59 | } |
60 | |
61 | private: |
62 | ListRefinement true_refine_; |
63 | ListRefinement false_refine_; |
64 | }; |
65 | |
// Joins the list-length refinements produced by the two blocks of `if_node`.
// From the parameter names: `curr_block_refinements` belongs to the block
// containing the if, `true_block_refinements` / `false_block_refinements` were
// collected in its then/else blocks, and `throwing_blocks` and `info` are
// shared analysis state. All of them are taken by non-const reference, so the
// callee may update them.
// NOTE(review): the exact merge rules (e.g. how a throwing branch is treated,
// and what is written back to `curr_block_refinements` / `info`) live in the
// out-of-line definition. Confirm there before relying on specifics.
TORCH_API void joinIfRefinements(
    Node* if_node,
    std::unordered_set<Block*>& throwing_blocks,
    ListRefinement& curr_block_refinements,
    ListRefinement& true_block_refinements,
    ListRefinement& false_block_refinements,
    std::unordered_map<Value*, BooleanRefinementMapping>& info);
73 | |
// Handles node `n` for the operators common to refinement passes: adds blocks
// to `throwing_blocks` and propagates refinements through boolean
// comparisons, recording them in `info`.
// Returns a bool, presumably whether `n` was handled here.
// NOTE(review): confirm the meaning of the return value against the
// definition.
// Note: "Refinent" in the name is a typo for "Refinement". It is kept as-is
// because renaming would break existing callers.
TORCH_API bool handleCommonRefinentOperators(
    Node* n,
    std::unordered_set<Block*>& throwing_blocks,
    std::unordered_map<Value*, BooleanRefinementMapping>& info);
80 | |
81 | } // namespace jit |
82 | } // namespace torch |
83 | |