1 | #include <cstdlib> |
2 | #include <iomanip> |
3 | #include <sstream> |
4 | #include <string> |
5 | #include <utility> |
6 | #include <vector> |
7 | |
8 | #include <ATen/core/function.h> |
9 | #include <c10/util/Exception.h> |
10 | #include <c10/util/StringUtil.h> |
11 | #include <torch/csrc/jit/api/function_impl.h> |
12 | #include <torch/csrc/jit/jit_opt_limit.h> |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | |
// Accessor for the per-pass usage counters consulted by opt_limit().
//
// The map is a function-local static: it is constructed on the first call and
// every call returns a reference to that same object. Keys are normalized pass
// names; values count how many times the pass has been allowed to run so far.
std::unordered_map<std::string, int64_t>& passes_to_current_counter() {
  using CounterMap = std::unordered_map<std::string, int64_t>;
  static CounterMap counters;
  return counters;
}
21 | |
// Parses the limit portion (the text after '=') of one PYTORCH_JIT_OPT_LIMIT
// entry.
//
// Returns the base-10 integer found at the start of `opt_limit`, or -1 when no
// integer can be read (e.g. an empty string or "abc"). Leading whitespace and
// an optional sign are accepted, and characters after the leading digits are
// ignored ("12abc" -> 12), matching std::stoi's prefix semantics.
//
// The value is parsed directly into a 64-bit integer to match the int64_t
// limits stored in the option map. This way a limit above INT_MAX is honoured
// instead of being treated as unparseable (and mapped to -1). Values outside
// the long long range saturate to LLONG_MIN / LLONG_MAX, per std::strtoll.
static int64_t parseOptLimit(const std::string& opt_limit) {
  const char* const begin = opt_limit.c_str();
  char* end = nullptr;
  const long long value = std::strtoll(begin, &end, 10);
  // strtoll leaves `end == begin` when no digits could be consumed.
  if (end == begin) {
    return -1;
  }
  return static_cast<int64_t>(value);
}
30 | |
31 | static std::unordered_map<std::string, int64_t> parseJITOptLimitOption( |
32 | const char* option) { |
33 | std::stringstream in_ss; |
34 | if (option) { |
35 | in_ss << option; |
36 | } |
37 | std::unordered_map<std::string, int64_t> passes_to_opt_limits; |
38 | std::string line; |
39 | while (std::getline(in_ss, line, ':')) { |
40 | if (line.empty()) { |
41 | continue; |
42 | } |
43 | auto index_at = line.find_last_of('='); |
44 | auto pass_name = line.substr(0, index_at); |
45 | pass_name = c10::detail::ExcludeFileExtension(pass_name); |
46 | auto opt_limit = parseOptLimit(line.substr(index_at + 1)); |
47 | passes_to_opt_limits.insert({pass_name, opt_limit}); |
48 | } |
49 | |
50 | return passes_to_opt_limits; |
51 | } |
52 | |
53 | bool opt_limit(const char* pass_name) { |
54 | static const char* opt_limit = std::getenv("PYTORCH_JIT_OPT_LIMIT" ); |
55 | // if nothing is provided, let's allow everything |
56 | if (!opt_limit) { |
57 | return true; |
58 | } |
59 | |
60 | static const std::unordered_map<std::string, int64_t> passes_to_opt_limits = |
61 | parseJITOptLimitOption(opt_limit); |
62 | std::string pass{pass_name}; |
63 | pass = c10::detail::StripBasename(pass); |
64 | pass = c10::detail::ExcludeFileExtension(pass); |
65 | |
66 | auto opt_limit_it = passes_to_opt_limits.find(pass); |
67 | if (opt_limit_it == passes_to_opt_limits.end()) { |
68 | return true; |
69 | } |
70 | |
71 | auto current_count_it = passes_to_current_counter().find(pass); |
72 | if (current_count_it == passes_to_current_counter().end()) { |
73 | passes_to_current_counter().insert({pass, 0}); |
74 | } |
75 | |
76 | current_count_it = passes_to_current_counter().find(pass); |
77 | if (current_count_it->second >= opt_limit_it->second) { |
78 | return false; |
79 | } |
80 | |
81 | current_count_it->second++; |
82 | return true; |
83 | } |
84 | |
85 | } // namespace jit |
86 | } // namespace torch |
87 | |