1 | #include <ATen/core/dispatch/Dispatcher.h> |
2 | #include <torch/csrc/utils/schema_info.h> |
3 | |
4 | namespace torch { |
5 | namespace utils { |
6 | void SchemaInfo::addArgumentValue( |
7 | const std::string& name, |
8 | const at::IValue& value) { |
9 | c10::optional<int> index = schema_.argumentIndexWithName(name); |
10 | TORCH_INTERNAL_ASSERT( |
11 | index != c10::nullopt, "Schema has no argument named " , name); |
12 | value_map_[name] = value; |
13 | alias_maps_current_ = false; |
14 | } |
15 | |
16 | void SchemaInfo::addArgumentValues( |
17 | const std::vector<c10::optional<at::IValue>>& value_list) { |
18 | TORCH_INTERNAL_ASSERT( |
19 | value_list.size() <= schema_.arguments().size(), |
20 | "Schema does not have enough arguments for value list" ); |
21 | |
22 | for (size_t i = 0; i < value_list.size(); i++) { |
23 | if (value_list[i] != c10::nullopt) { |
24 | value_map_[schema_.arguments()[i].name()] = *(value_list[i]); |
25 | alias_maps_current_ = false; |
26 | } |
27 | } |
28 | } |
29 | |
30 | void SchemaInfo::addArgumentValues( |
31 | const std::unordered_map<std::string, at::IValue>& values) { |
32 | for (const auto& key_pair : values) { |
33 | addArgumentValue(key_pair.first, key_pair.second); |
34 | } |
35 | } |
36 | |
37 | bool SchemaInfo::hasInputArgumentNamed(const std::string& name) const { |
38 | return std::any_of( |
39 | schema_.arguments().begin(), |
40 | schema_.arguments().end(), |
41 | [&name](const c10::Argument& arg) { return arg.name() == name; }); |
42 | } |
43 | |
44 | bool SchemaInfo::is_mutable() { |
45 | for (size_t i = 0; i < schema_.arguments().size(); i++) { |
46 | if (is_mutable({c10::SchemaArgType::input, i})) { |
47 | return true; |
48 | } |
49 | } |
50 | return false; |
51 | } |
52 | |
53 | bool SchemaInfo::is_mutable(const c10::SchemaArgument& argument) { |
54 | TORCH_INTERNAL_ASSERT( |
55 | argument.index < schema_.getCorrectList(argument.type).size(), |
56 | "Invalid index for schema." ); |
57 | if (!alias_maps_current_) { |
58 | generateAliasMaps(); |
59 | } |
60 | static const std::vector<SchemaSpecialCasePair> training_ops = |
61 | getTrainingOps(); |
62 | const auto& correct_map = (argument.type == c10::SchemaArgType::input) |
63 | ? input_alias_map_ |
64 | : output_alias_map_; |
65 | // Note that the training_op checks depend on index because |
66 | // of cases where either running_mean or running_var alias another input |
67 | // argument causing its alias status to change. |
68 | return std::any_of( |
69 | correct_map[argument.index].begin(), |
70 | correct_map[argument.index].end(), |
71 | [this](size_t aliasing_index) { |
72 | const auto is_training_op = std::find_if( |
73 | training_ops.begin(), |
74 | training_ops.end(), |
75 | [this](const auto& training_op) { |
76 | return this->schema_ == training_op.first; |
77 | }); |
78 | |
79 | bool special_case = (is_training_op != training_ops.end()) && |
80 | is_training_op->second.count( |
81 | this->schema_.arguments()[aliasing_index].name()); |
82 | if (special_case) { |
83 | bool has_training = (hasInputArgumentNamed("training" ) && |
84 | !value_map_.count("training" )) || |
85 | (value_map_.count("training" ) && |
86 | value_map_.at("training" ).toBool()); |
87 | bool has_train = |
88 | (hasInputArgumentNamed("train" ) && !value_map_.count("train" )) || |
89 | (value_map_.count("train" ) && value_map_.at("train" ).toBool()); |
90 | bool has_use_input_stats = |
91 | (hasInputArgumentNamed("use_input_stats" ) && |
92 | !value_map_.count("use_input_stats" )) || |
93 | (value_map_.count("use_input_stats" ) && |
94 | value_map_.at("use_input_stats" ).toBool()); |
95 | return has_training || has_train || has_use_input_stats; |
96 | } else { |
97 | return this->schema_.is_mutable( |
98 | {c10::SchemaArgType::input, aliasing_index}); |
99 | } |
100 | }); |
101 | } |
102 | |
103 | bool SchemaInfo::has_argument(c10::string_view name) { |
104 | return schema_.argumentIndexWithName(name) != c10::nullopt; |
105 | } |
106 | |
107 | bool SchemaInfo::is_mutable(c10::string_view name) { |
108 | c10::optional<int> index = schema_.argumentIndexWithName(name); |
109 | TORCH_INTERNAL_ASSERT( |
110 | index != c10::nullopt, "Schema has no argument named " , name); |
111 | |
112 | return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)}); |
113 | } |
114 | |
// Returns true if this op may produce different results across runs with
// identical inputs. Special case: aten::dropout with train=false is an
// identity passthrough and therefore deterministic even though dropout is
// otherwise nondeterministic.
bool SchemaInfo::is_nondeterministic() const {
  static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor");
  if (dropout_schema == schema_ && value_map_.count("train") &&
      !value_map_.at("train").toBool()) {
    return false;
  }

#if defined C10_MOBILE
  // Mobile builds cannot query dispatcher tags, so fall back to the
  // hard-coded list of nondeterministic schemas.
  static const std::vector<c10::FunctionSchema> nondeterministic_ops =
      getNonDeterministicOps();
  return std::any_of(
      nondeterministic_ops.begin(),
      nondeterministic_ops.end(),
      [this](const c10::FunctionSchema& nondeterministic_op) {
        return nondeterministic_op == this->schema_;
      });
#else
  // Server builds: consult the dispatcher's operator tags directly.
  const auto& op = c10::Dispatcher::singleton().findOp(
      c10::OperatorName(schema_.name(), schema_.overload_name()));
  return op && op->hasTag(at::Tag::nondeterministic_seeded);
#endif
}
138 | |
// Determines whether two schema arguments may alias, combining the static
// schema alias annotations with the dynamically generated alias maps built
// from the concrete values registered via addArgumentValue(s).
bool SchemaInfo::may_alias(
    const c10::SchemaArgument& lhs,
    const c10::SchemaArgument& rhs) {
  // Static schema-level answer first; if the schema already says the two
  // may alias, no value-based analysis is needed.
  bool basic_check = schema_.may_alias(lhs, rhs);
  if (basic_check) {
    return true;
  }
  // If the two argument types can never alias (e.g. Tensor vs. int),
  // concrete values cannot change that.
  c10::optional<c10::AliasTypeSet> lhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(lhs.type)[lhs.index].type());
  c10::optional<c10::AliasTypeSet> rhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(rhs.type)[rhs.index].type());
  bool types_can_alias =
      schema_.canAliasTypeSetsAlias(lhsAliasTypeSet, rhsAliasTypeSet);
  if (!types_can_alias) {
    return false;
  }

  // Alias maps reflect the currently registered values; rebuild if a value
  // was added since the last generation.
  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  // Two wildcard arguments may always alias each other.
  bool wildcard_alias_check =
      wildcardSet().count(lhs) && wildcardSet().count(rhs);
  if (wildcard_alias_check) {
    return true;
  }

  if (lhs.type == c10::SchemaArgType::input &&
      rhs.type == c10::SchemaArgType::input) {
    // input/input: direct lookup in the symmetric input alias map.
    return input_alias_map_[lhs.index].count(rhs.index);
  } else if (
      lhs.type == c10::SchemaArgType::output &&
      rhs.type == c10::SchemaArgType::output) {
    // output/output: two outputs may alias if they both alias a common
    // input (output_alias_map_ maps each output to the inputs it aliases).
    for (size_t lhs_alias_input : output_alias_map_[lhs.index]) {
      if (output_alias_map_[rhs.index].count(lhs_alias_input)) {
        return true;
      }
    }
    return false;
  } else if (lhs.type == c10::SchemaArgType::output) {
    // output/input: check whether rhs is among the inputs lhs aliases.
    return output_alias_map_[lhs.index].count(rhs.index);
  } else {
    // input/output: symmetric to the previous case.
    return output_alias_map_[rhs.index].count(lhs.index);
  }
}
185 | |
186 | bool SchemaInfo::may_contain_alias( |
187 | const c10::SchemaArgument& lhs, |
188 | const c10::SchemaArgument& rhs, |
189 | bool bidirectional) { |
190 | bool basic_check = schema_.may_contain_alias(lhs, rhs) || may_alias(lhs, rhs); |
191 | if (basic_check) { |
192 | return true; |
193 | } |
194 | if (!alias_maps_current_) { |
195 | generateAliasMaps(); |
196 | } |
197 | if (bidirectional) { |
198 | return mayContainAliasImpl(lhs, rhs) || mayContainAliasImpl(rhs, lhs); |
199 | } else { |
200 | return mayContainAliasImpl(lhs, rhs); |
201 | } |
202 | } |
203 | |
204 | bool SchemaInfo::mayContainAliasImpl( |
205 | const c10::SchemaArgument& lhs, |
206 | const c10::SchemaArgument& rhs) { |
207 | c10::optional<c10::AliasTypeSet> lhsContainedAliasTypeSet = |
208 | schema_.getAliasTypeSetContainedTypes(schema_.mapTypeToAliasTypeSet( |
209 | schema_.getCorrectList(lhs.type)[lhs.index].type())); |
210 | c10::optional<c10::AliasTypeSet> rhsAliasTypeSet = |
211 | schema_.mapTypeToAliasTypeSet( |
212 | schema_.getCorrectList(rhs.type)[rhs.index].type()); |
213 | bool types_can_alias = |
214 | schema_.canAliasTypeSetsAlias(lhsContainedAliasTypeSet, rhsAliasTypeSet); |
215 | return types_can_alias && containerSet().count(lhs) && |
216 | wildcardSet().count(rhs); |
217 | } |
218 | |
219 | void SchemaInfo::ensureConservativity( |
220 | const std::unordered_set<at::Symbol>& duplicates, |
221 | const std::vector<c10::Argument>& arguments_list, |
222 | c10::SchemaArgType type) { |
223 | for (size_t i = 0; i < arguments_list.size(); i++) { |
224 | if (arguments_list[i].alias_info()) { |
225 | for (const auto& set : arguments_list[i].alias_info()->afterSets()) { |
226 | if (duplicates.count(set)) { |
227 | wildcard_set_.insert({type, i}); |
228 | } |
229 | } |
230 | } |
231 | } |
232 | } |
233 | |
// Builds the hard-coded list of schemas for ops known to be
// nondeterministic. Used as a fallback on mobile builds where dispatcher
// tags are unavailable (see is_nondeterministic()).
std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
  // This list of nondeterministic ops is copied from JIT ir.cpp.
  static const std::vector<std::string> nondeterministic_op_strings = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};

  // Parse each signature string once into a FunctionSchema.
  std::vector<c10::FunctionSchema> nondeterministic_ops;
  nondeterministic_ops.reserve(nondeterministic_op_strings.size());
  for (const std::string& signature : nondeterministic_op_strings) {
    nondeterministic_ops.emplace_back(torch::jit::parseSchema(signature));
  }

  return nondeterministic_ops;
}
269 | |
270 | std::vector<SchemaSpecialCasePair> SchemaInfo::getTrainingOps() { |
271 | // This is a list of pairs of ops to sets of strings |
272 | // where the a boolean variable (either "training", |
273 | // "train" or "use_input_stats") affects the mutability |
274 | // of the unorderered set of strings. |
275 | static const std::vector<std::pair<std::string, std::unordered_set<std::string>>> training_op_pairs = |
276 | {{"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor" , |
277 | {"running_mean" , "running_var" }}, |
278 | {"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor" , |
279 | {"running_mean" , "running_var" }}, |
280 | {"aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)" , |
281 | {"running_mean" , "running_var" }}, |
282 | {"aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)" , |
283 | {"running_mean" , "running_var" }}, |
284 | {"aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)" , |
285 | {"running_mean" , "running_var" }}, |
286 | {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)" , |
287 | {"running_mean" , "running_var" }}, |
288 | {"aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))" , |
289 | {"running_mean" , "running_var" }}, |
290 | {"aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor" , |
291 | {"noise" }}, |
292 | {"aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)" , |
293 | {"noise" }}, |
294 | {"rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)" , |
295 | {"noise" }}}; |
296 | |
297 | std::vector<SchemaSpecialCasePair> training_ops; |
298 | training_ops.reserve(training_op_pairs.size()); |
299 | for (const auto& signature : training_op_pairs) { |
300 | training_ops.emplace_back( |
301 | torch::jit::parseSchema(signature.first), signature.second); |
302 | } |
303 | |
304 | return training_ops; |
305 | } |
306 | |
// Lazily initializes wildcard_set_ and container_set_ from the schema's
// alias annotations. Idempotent: guarded by has_init_.
void SchemaInfo::initSchemaInfo() {
  if (has_init_) {
    return;
  }
  has_init_ = true;

  // Alias sets that appear more than once within a single argument list;
  // these force conservative (wildcard) treatment below.
  std::unordered_set<at::Symbol> duplicates;
  auto init_schema_arguments = [this, &duplicates](
                                   const std::vector<c10::Argument>&
                                       arguments_list,
                                   c10::SchemaArgType type) {
    std::unordered_set<at::Symbol> seen;
    for (size_t i = 0; i < arguments_list.size(); i++) {
      const c10::Argument& argument = arguments_list[i];
      if (argument.alias_info()) {
        if (argument.alias_info()->isWildcardAfter()) {
          wildcard_set_.insert({type, i});
        } else {
          // This check is to ensure that the FunctionSchema will accurately
          // be represented when calling may_alias and may_contain_alias
          // on schemas with more than one argument within arguments_list that
          // shares an alias set.
          for (const auto& set : argument.alias_info()->afterSets()) {
            if (seen.count(set)) {
              TORCH_WARN(
                  set.toQualString(),
                  " appears twice in same argument list which will make aliasing checks more conservative.");
              duplicates.insert(set);
            } else {
              seen.insert(set);
            }
          }
        }
      }
      // Arguments whose type can contain other aliasable types (e.g.
      // Tensor[]) are tracked as containers for may_contain_alias.
      c10::optional<c10::AliasTypeSet> contained_types =
          schema_.getAliasTypeSetContainedTypes(
              schema_.mapTypeToAliasTypeSet(argument.type()));
      if (contained_types && !contained_types->empty()) {
        container_set_.insert({type, i});
      }
    }
  };

  init_schema_arguments(schema_.arguments(), c10::SchemaArgType::input);
  init_schema_arguments(schema_.returns(), c10::SchemaArgType::output);
  // Promote every argument touching a duplicated alias set to wildcard.
  ensureConservativity(
      duplicates, schema_.arguments(), c10::SchemaArgType::input);
  ensureConservativity(
      duplicates, schema_.returns(), c10::SchemaArgType::output);
}
357 | |
// Accessor for the wildcard argument set; ensures lazy init has run.
const std::unordered_set<c10::SchemaArgument>& SchemaInfo::wildcardSet() {
  initSchemaInfo();
  return wildcard_set_;
}
362 | |
// Accessor for the container argument set; ensures lazy init has run.
const std::unordered_set<c10::SchemaArgument>& SchemaInfo::containerSet() {
  initSchemaInfo();
  return container_set_;
}
367 | |
// Rebuilds input_alias_map_ and output_alias_map_ from the currently
// registered argument values, and extends wildcard_set_ accordingly.
// input_alias_map_[i] holds the input indices whose values alias input i;
// output_alias_map_[j] holds the input indices output j may alias.
void SchemaInfo::generateAliasMaps() {
  initSchemaInfo();

  alias_maps_current_ = true;
  input_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.arguments().size(), std::unordered_set<size_t>());
  output_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.returns().size(), std::unordered_set<size_t>());

  // Fills input_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = i; j < schema_.arguments().size(); j++) {
      if (i == j) {
        // Every argument trivially aliases itself.
        input_alias_map_[i].insert(i);
      } else if (
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        // Two inputs alias when their registered IValues share storage;
        // record the relation symmetrically.
        if (value_map_[schema_.arguments()[i].name()].isAliasOf(
                value_map_[schema_.arguments()[j].name()])) {
          input_alias_map_[i].insert(j);
          input_alias_map_[j].insert(i);
          // Wildcard status propagates across aliasing inputs.
          if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
            wildcard_set_.insert({c10::SchemaArgType::input, j});
          } else if (wildcard_set_.count({c10::SchemaArgType::input, j})) {
            wildcard_set_.insert({c10::SchemaArgType::input, i});
          }
        }
      }
    }
  }

  // Fills wildcard_set with container created wildcards.
  // For instance, given the schema:
  // test(Tensor a, Tensor(*) b, Tensor[] c) -> Tensor
  // where value(a) is contained in value(c), then a will be added to the
  // wildcard set where it can now alias b.
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.arguments().size(); j++) {
      // if they are already aliasing, there is no way one contains the other
      if (!input_alias_map_[i].count(j) &&
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        c10::IValue::HashAliasedIValues subValues;
        value_map_[schema_.arguments()[i].name()].getSubValues(subValues);
        if (subValues.count(value_map_[schema_.arguments()[j].name()])) {
          wildcard_set_.insert({c10::SchemaArgType::input, j});
        }
      }
    }
  }

  // Fills output_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.returns().size(); j++) {
      if (schema_.may_alias(
              {c10::SchemaArgType::input, i},
              {c10::SchemaArgType::output, j})) {
        // An output aliasing a wildcard input becomes wildcard itself.
        if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
          wildcard_set_.insert({c10::SchemaArgType::output, j});
        }
        // The output inherits the full alias class of the input.
        output_alias_map_[j].insert(
            input_alias_map_[i].begin(), input_alias_map_[i].end());
      }
    }
  }
}
434 | |
435 | } // namespace utils |
436 | } // namespace torch |
437 | |