1#include <cstdlib>
2#include <iomanip>
3#include <sstream>
4#include <string>
5#include <utility>
6#include <vector>
7
8#include <ATen/core/function.h>
9#include <c10/util/Exception.h>
10#include <c10/util/StringUtil.h>
11#include <torch/csrc/jit/api/function_impl.h>
12#include <torch/csrc/jit/jit_opt_limit.h>
13
14namespace torch {
15namespace jit {
16
// Returns the process-wide table that maps a pass name to a running counter.
// The table is a function-local static, so it is created on first use and
// every call hands back a reference to the same object.
std::unordered_map<std::string, int64_t>& passes_to_current_counter() {
  static std::unordered_map<std::string, int64_t> counters;
  return counters;
}
21
22static int parseOptLimit(const std::string& opt_limit) {
23 try {
24 int64_t n = c10::stoi(opt_limit);
25 return n;
26 } catch (...) {
27 return -1;
28 }
29}
30
31static std::unordered_map<std::string, int64_t> parseJITOptLimitOption(
32 const char* option) {
33 std::stringstream in_ss;
34 if (option) {
35 in_ss << option;
36 }
37 std::unordered_map<std::string, int64_t> passes_to_opt_limits;
38 std::string line;
39 while (std::getline(in_ss, line, ':')) {
40 if (line.empty()) {
41 continue;
42 }
43 auto index_at = line.find_last_of('=');
44 auto pass_name = line.substr(0, index_at);
45 pass_name = c10::detail::ExcludeFileExtension(pass_name);
46 auto opt_limit = parseOptLimit(line.substr(index_at + 1));
47 passes_to_opt_limits.insert({pass_name, opt_limit});
48 }
49
50 return passes_to_opt_limits;
51}
52
53bool opt_limit(const char* pass_name) {
54 static const char* opt_limit = std::getenv("PYTORCH_JIT_OPT_LIMIT");
55 // if nothing is provided, let's allow everything
56 if (!opt_limit) {
57 return true;
58 }
59
60 static const std::unordered_map<std::string, int64_t> passes_to_opt_limits =
61 parseJITOptLimitOption(opt_limit);
62 std::string pass{pass_name};
63 pass = c10::detail::StripBasename(pass);
64 pass = c10::detail::ExcludeFileExtension(pass);
65
66 auto opt_limit_it = passes_to_opt_limits.find(pass);
67 if (opt_limit_it == passes_to_opt_limits.end()) {
68 return true;
69 }
70
71 auto current_count_it = passes_to_current_counter().find(pass);
72 if (current_count_it == passes_to_current_counter().end()) {
73 passes_to_current_counter().insert({pass, 0});
74 }
75
76 current_count_it = passes_to_current_counter().find(pass);
77 if (current_count_it->second >= opt_limit_it->second) {
78 return false;
79 }
80
81 current_count_it->second++;
82 return true;
83}
84
85} // namespace jit
86} // namespace torch
87