1#pragma once
2
3#include <ATen/core/jit_type.h>
4#include <torch/csrc/jit/ir/alias_analysis.h>
5#include <torch/csrc/jit/ir/ir_views.h>
6#include <torch/csrc/jit/jit_log.h>
7#include <torch/csrc/jit/passes/dead_code_elimination.h>
8#include <torch/csrc/jit/passes/peephole.h>
9#include <torch/csrc/jit/passes/peephole_list_idioms.h>
10#include <torch/csrc/jit/runtime/graph_executor.h>
11#include <torch/csrc/utils/memory.h>
12
13namespace torch {
14namespace jit {
15
// Refinement from a Value of List type to the length of that list.
// If a refinement mapping (list Value* -> len) is present in a block, the
// list is guaranteed to have that length in that block.
// TODO: vector may be faster
using ListRefinement = std::unordered_map<Value*, int64_t>;
21
22TORCH_API ListRefinement
23intersectRefinements(const ListRefinement& ref1, const ListRefinement& ref2);
24
25TORCH_API ListRefinement
26unionRefinements(const ListRefinement& ref1, const ListRefinement& ref2);
27
28// Represents the refinement information that can be carried on a boolean
29struct BooleanRefinementMapping {
30 BooleanRefinementMapping(
31 ListRefinement true_refine,
32 ListRefinement false_refine)
33 : true_refine_(std::move(true_refine)),
34 false_refine_(std::move(false_refine)){};
35 BooleanRefinementMapping() = default; // empty
36
37 static BooleanRefinementMapping FalseRefinements(
38 ListRefinement false_refine) {
39 return BooleanRefinementMapping({}, std::move(false_refine));
40 }
41
42 static BooleanRefinementMapping TrueRefinements(ListRefinement true_refine) {
43 return BooleanRefinementMapping(std::move(true_refine), {});
44 }
45
46 BooleanRefinementMapping intersectBooleanRefinementMapping(
47 BooleanRefinementMapping& other) {
48 return BooleanRefinementMapping(
49 intersectRefinements(true_refine_, other.true_refine()),
50 intersectRefinements(false_refine_, other.false_refine()));
51 }
52
53 ListRefinement& true_refine() {
54 return true_refine_;
55 }
56
57 ListRefinement& false_refine() {
58 return false_refine_;
59 }
60
61 private:
62 ListRefinement true_refine_;
63 ListRefinement false_refine_;
64};
65
66TORCH_API void joinIfRefinements(
67 Node* if_node,
68 std::unordered_set<Block*>& throwing_blocks,
69 ListRefinement& curr_block_refinements,
70 ListRefinement& true_block_refinements,
71 ListRefinement& false_block_refinements,
72 std::unordered_map<Value*, BooleanRefinementMapping>& info);
73
74// handles adding blocks to throwing blocks and propagating refinements via
75// boolean comparisons
76TORCH_API bool handleCommonRefinentOperators(
77 Node* n,
78 std::unordered_set<Block*>& throwing_blocks,
79 std::unordered_map<Value*, BooleanRefinementMapping>& info);
80
81} // namespace jit
82} // namespace torch
83