#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/utils/schema_info.h>

namespace torch {
namespace utils {
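// Records a concrete value for the named argument so that aliasing queries
// can reason about the actual inputs rather than only the schema. Any
// previously computed alias maps become stale.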
void SchemaInfo::addArgumentValue(
    const std::string& name,
    const at::IValue& value) {
  c10::optional<int> index = schema_.argumentIndexWithName(name);
  TORCH_INTERNAL_ASSERT(
      index != c10::nullopt, "Schema has no argument named ", name);
  value_map_[name] = value;
  alias_maps_current_ = false;
}

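// Positional variant: value_list[i], when present, is recorded under the
// name of the i-th schema argument; c10::nullopt entries are skipped.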
void SchemaInfo::addArgumentValues(
    const std::vector<c10::optional<at::IValue>>& value_list) {
  TORCH_INTERNAL_ASSERT(
      value_list.size() <= schema_.arguments().size(),
      "Schema does not have enough arguments for value list");

  for (size_t i = 0; i < value_list.size(); i++) {
    if (value_list[i] != c10::nullopt) {
      value_map_[schema_.arguments()[i].name()] = *(value_list[i]);
      alias_maps_current_ = false;
    }
  }
}

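// Name-keyed variant; forwards each entry to addArgumentValue.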
void SchemaInfo::addArgumentValues(
    const std::unordered_map<std::string, at::IValue>& values) {
  for (const auto& key_pair : values) {
    addArgumentValue(key_pair.first, key_pair.second);
  }
}

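// True if any of the schema's input arguments carries this name.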
bool SchemaInfo::hasInputArgumentNamed(const std::string& name) const {
  return std::any_of(
      schema_.arguments().begin(),
      schema_.arguments().end(),
      [&name](const c10::Argument& arg) { return arg.name() == name; });
}

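// An operator is considered mutable if any of its inputs may be mutated.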
bool SchemaInfo::is_mutable() {
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    if (is_mutable({c10::SchemaArgType::input, i})) {
      return true;
    }
  }
  return false;
}

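// Determines whether the given argument may be mutated. Training ops are
// special-cased: their running-stats (or noise) arguments are treated as
// mutable only when the relevant training flag is true or not yet known.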
bool SchemaInfo::is_mutable(const c10::SchemaArgument& argument) {
  TORCH_INTERNAL_ASSERT(
      argument.index < schema_.getCorrectList(argument.type).size(),
      "Invalid index for schema.");
  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  static const std::vector<SchemaSpecialCasePair> training_ops =
      getTrainingOps();
  const auto& correct_map = (argument.type == c10::SchemaArgType::input)
      ? input_alias_map_
      : output_alias_map_;
  // The training-op check must run per aliasing index: running_mean or
  // running_var may alias another input argument, which changes that
  // argument's mutability status.
  return std::any_of(
      correct_map[argument.index].begin(),
      correct_map[argument.index].end(),
      [this](size_t aliasing_index) {
        const auto is_training_op = std::find_if(
            training_ops.begin(),
            training_ops.end(),
            [this](const auto& training_op) {
              return this->schema_ == training_op.first;
            });

        bool special_case = (is_training_op != training_ops.end()) &&
            is_training_op->second.count(
                this->schema_.arguments()[aliasing_index].name());
        if (special_case) {
          bool has_training = (hasInputArgumentNamed("training") &&
                               !value_map_.count("training")) ||
              (value_map_.count("training") &&
               value_map_.at("training").toBool());
          bool has_train =
              (hasInputArgumentNamed("train") && !value_map_.count("train")) ||
              (value_map_.count("train") && value_map_.at("train").toBool());
          bool has_use_input_stats =
              (hasInputArgumentNamed("use_input_stats") &&
               !value_map_.count("use_input_stats")) ||
              (value_map_.count("use_input_stats") &&
               value_map_.at("use_input_stats").toBool());
          return has_training || has_train || has_use_input_stats;
        } else {
          return this->schema_.is_mutable(
              {c10::SchemaArgType::input, aliasing_index});
        }
      });
}

bool SchemaInfo::has_argument(c10::string_view name) {
  return schema_.argumentIndexWithName(name) != c10::nullopt;
}

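// Name-based convenience overload of is_mutable.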
bool SchemaInfo::is_mutable(c10::string_view name) {
  c10::optional<int> index = schema_.argumentIndexWithName(name);
  TORCH_INTERNAL_ASSERT(
      index != c10::nullopt, "Schema has no argument named ", name);

  return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
}

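// Reports whether the op may produce different results across runs.
// aten::dropout with train=false is special-cased as deterministic. On
// mobile builds a static list of known nondeterministic schemas is
// consulted; elsewhere the dispatcher's nondeterministic_seeded tag is
// queried.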
bool SchemaInfo::is_nondeterministic() const {
  static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema(
      "aten::dropout(Tensor input, float p, bool train) -> Tensor");
  if (dropout_schema == schema_ && value_map_.count("train") &&
      !value_map_.at("train").toBool()) {
    return false;
  }

#if defined C10_MOBILE
  static const std::vector<c10::FunctionSchema> nondeterministic_ops =
      getNonDeterministicOps();
  return std::any_of(
      nondeterministic_ops.begin(),
      nondeterministic_ops.end(),
      [this](const c10::FunctionSchema& nondeterministic_op) {
        return nondeterministic_op == this->schema_;
      });
#else
  const auto& op = c10::Dispatcher::singleton().findOp(
      c10::OperatorName(schema_.name(), schema_.overload_name()));
  return op && op->hasTag(at::Tag::nondeterministic_seeded);
#endif
}

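// Conservatively answers whether two arguments may alias. The cheap
// schema-level check runs first; failing that, the types must be able to
// alias at all, after which the wildcard sets and the value-derived alias
// maps are consulted.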
bool SchemaInfo::may_alias(
    const c10::SchemaArgument& lhs,
    const c10::SchemaArgument& rhs) {
  bool basic_check = schema_.may_alias(lhs, rhs);
  if (basic_check) {
    return true;
  }
  c10::optional<c10::AliasTypeSet> lhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(lhs.type)[lhs.index].type());
  c10::optional<c10::AliasTypeSet> rhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(rhs.type)[rhs.index].type());
  bool types_can_alias =
      schema_.canAliasTypeSetsAlias(lhsAliasTypeSet, rhsAliasTypeSet);
  if (!types_can_alias) {
    return false;
  }

  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  bool wildcard_alias_check =
      wildcardSet().count(lhs) && wildcardSet().count(rhs);
  if (wildcard_alias_check) {
    return true;
  }

  if (lhs.type == c10::SchemaArgType::input &&
      rhs.type == c10::SchemaArgType::input) {
    return input_alias_map_[lhs.index].count(rhs.index);
  } else if (
      lhs.type == c10::SchemaArgType::output &&
      rhs.type == c10::SchemaArgType::output) {
    for (size_t lhs_alias_input : output_alias_map_[lhs.index]) {
      if (output_alias_map_[rhs.index].count(lhs_alias_input)) {
        return true;
      }
    }
    return false;
  } else if (lhs.type == c10::SchemaArgType::output) {
    return output_alias_map_[lhs.index].count(rhs.index);
  } else {
    return output_alias_map_[rhs.index].count(lhs.index);
  }
}

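// Answers whether one argument may contain a value that aliases the other
// (e.g. a Tensor held inside a Tensor[]). With bidirectional=true the
// containment check is applied in both directions.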
bool SchemaInfo::may_contain_alias(
    const c10::SchemaArgument& lhs,
    const c10::SchemaArgument& rhs,
    bool bidirectional) {
  bool basic_check = schema_.may_contain_alias(lhs, rhs) || may_alias(lhs, rhs);
  if (basic_check) {
    return true;
  }
  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  if (bidirectional) {
    return mayContainAliasImpl(lhs, rhs) || mayContainAliasImpl(rhs, lhs);
  } else {
    return mayContainAliasImpl(lhs, rhs);
  }
}

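// lhs may contain an alias of rhs only if lhs is a known container, rhs is
// in the wildcard set, and the types contained in lhs can alias rhs's type.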
bool SchemaInfo::mayContainAliasImpl(
    const c10::SchemaArgument& lhs,
    const c10::SchemaArgument& rhs) {
  c10::optional<c10::AliasTypeSet> lhsContainedAliasTypeSet =
      schema_.getAliasTypeSetContainedTypes(schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(lhs.type)[lhs.index].type()));
  c10::optional<c10::AliasTypeSet> rhsAliasTypeSet =
      schema_.mapTypeToAliasTypeSet(
          schema_.getCorrectList(rhs.type)[rhs.index].type());
  bool types_can_alias =
      schema_.canAliasTypeSetsAlias(lhsContainedAliasTypeSet, rhsAliasTypeSet);
  return types_can_alias && containerSet().count(lhs) &&
      wildcardSet().count(rhs);
}

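// Marks every argument whose alias set appears in `duplicates` as a
// wildcard, so repeated alias annotations are handled conservatively.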
void SchemaInfo::ensureConservativity(
    const std::unordered_set<at::Symbol>& duplicates,
    const std::vector<c10::Argument>& arguments_list,
    c10::SchemaArgType type) {
  for (size_t i = 0; i < arguments_list.size(); i++) {
    if (arguments_list[i].alias_info()) {
      for (const auto& set : arguments_list[i].alias_info()->afterSets()) {
        if (duplicates.count(set)) {
          wildcard_set_.insert({type, i});
        }
      }
    }
  }
}

std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
  // This list of nondeterministic ops is copied from JIT ir.cpp.
  static const std::vector<std::string> nondeterministic_op_strings = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};

  std::vector<c10::FunctionSchema> nondeterministic_ops;
  nondeterministic_ops.reserve(nondeterministic_op_strings.size());
  for (const std::string& signature : nondeterministic_op_strings) {
    nondeterministic_ops.emplace_back(torch::jit::parseSchema(signature));
  }

  return nondeterministic_ops;
}

std::vector<SchemaSpecialCasePair> SchemaInfo::getTrainingOps() {
  // This is a list of pairs mapping each op to the unordered set of
  // argument names whose mutability depends on a boolean flag (either
  // "training", "train", or "use_input_stats").
  static const std::vector<std::pair<std::string, std::unordered_set<std::string>>> training_op_pairs =
      {{"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
        {"running_mean", "running_var"}},
       {"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor",
        {"running_mean", "running_var"}},
       {"aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
        {"running_mean", "running_var"}},
       {"aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)",
        {"running_mean", "running_var"}},
       {"aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)",
        {"running_mean", "running_var"}},
       {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
        {"running_mean", "running_var"}},
       {"aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))",
        {"running_mean", "running_var"}},
       {"aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor",
        {"noise"}},
       {"aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)",
        {"noise"}},
       {"aten::rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)",
        {"noise"}}};

  std::vector<SchemaSpecialCasePair> training_ops;
  training_ops.reserve(training_op_pairs.size());
  for (const auto& signature : training_op_pairs) {
    training_ops.emplace_back(
        torch::jit::parseSchema(signature.first), signature.second);
  }

  return training_ops;
}

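// Lazily populates wildcard_set_ and container_set_ from the schema's alias
// annotations. Alias sets that appear more than once within an argument
// list are recorded as duplicates and later handled conservatively.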
void SchemaInfo::initSchemaInfo() {
  if (has_init_) {
    return;
  }
  has_init_ = true;

  std::unordered_set<at::Symbol> duplicates;
  auto init_schema_arguments = [this, &duplicates](
                                   const std::vector<c10::Argument>&
                                       arguments_list,
                                   c10::SchemaArgType type) {
    std::unordered_set<at::Symbol> seen;
    for (size_t i = 0; i < arguments_list.size(); i++) {
      const c10::Argument& argument = arguments_list[i];
      if (argument.alias_info()) {
        if (argument.alias_info()->isWildcardAfter()) {
          wildcard_set_.insert({type, i});
        } else {
          // This check ensures that may_alias and may_contain_alias stay
          // accurate for schemas where more than one argument within
          // arguments_list shares an alias set.
          for (const auto& set : argument.alias_info()->afterSets()) {
            if (seen.count(set)) {
              TORCH_WARN(
                  set.toQualString(),
                  " appears twice in same argument list which will make aliasing checks more conservative.");
              duplicates.insert(set);
            } else {
              seen.insert(set);
            }
          }
        }
      }
      c10::optional<c10::AliasTypeSet> contained_types =
          schema_.getAliasTypeSetContainedTypes(
              schema_.mapTypeToAliasTypeSet(argument.type()));
      if (contained_types && !contained_types->empty()) {
        container_set_.insert({type, i});
      }
    }
  };

  init_schema_arguments(schema_.arguments(), c10::SchemaArgType::input);
  init_schema_arguments(schema_.returns(), c10::SchemaArgType::output);
  ensureConservativity(
      duplicates, schema_.arguments(), c10::SchemaArgType::input);
  ensureConservativity(
      duplicates, schema_.returns(), c10::SchemaArgType::output);
}

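// The accessors below initialize lazily so the sets are valid even before
// any alias map has been generated.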
const std::unordered_set<c10::SchemaArgument>& SchemaInfo::wildcardSet() {
  initSchemaInfo();
  return wildcard_set_;
}

const std::unordered_set<c10::SchemaArgument>& SchemaInfo::containerSet() {
  initSchemaInfo();
  return container_set_;
}

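// Rebuilds input_alias_map_ and output_alias_map_ from the recorded
// argument values: inputs whose IValues alias one another are linked,
// containment promotes contained inputs to wildcards, and schema-level
// input/output aliasing is propagated to the outputs.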
void SchemaInfo::generateAliasMaps() {
  initSchemaInfo();

  alias_maps_current_ = true;
  input_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.arguments().size(), std::unordered_set<size_t>());
  output_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.returns().size(), std::unordered_set<size_t>());

  // Fills input_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = i; j < schema_.arguments().size(); j++) {
      if (i == j) {
        input_alias_map_[i].insert(i);
      } else if (
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        if (value_map_[schema_.arguments()[i].name()].isAliasOf(
                value_map_[schema_.arguments()[j].name()])) {
          input_alias_map_[i].insert(j);
          input_alias_map_[j].insert(i);
          if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
            wildcard_set_.insert({c10::SchemaArgType::input, j});
          } else if (wildcard_set_.count({c10::SchemaArgType::input, j})) {
            wildcard_set_.insert({c10::SchemaArgType::input, i});
          }
        }
      }
    }
  }

  // Fills wildcard_set_ with container-created wildcards. For instance,
  // given the schema:
  //   test(Tensor a, Tensor(*) b, Tensor[] c) -> Tensor
  // if value(a) is contained in value(c), then a is added to the wildcard
  // set, where it can now alias b.
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.arguments().size(); j++) {
      // If they are already aliasing, there is no way one contains the other.
      if (!input_alias_map_[i].count(j) &&
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        c10::IValue::HashAliasedIValues subValues;
        value_map_[schema_.arguments()[i].name()].getSubValues(subValues);
        if (subValues.count(value_map_[schema_.arguments()[j].name()])) {
          wildcard_set_.insert({c10::SchemaArgType::input, j});
        }
      }
    }
  }

  // Fills output_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.returns().size(); j++) {
      if (schema_.may_alias(
              {c10::SchemaArgType::input, i},
              {c10::SchemaArgType::output, j})) {
        if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
          wildcard_set_.insert({c10::SchemaArgType::output, j});
        }
        output_alias_map_[j].insert(
            input_alias_map_[i].begin(), input_alias_map_[i].end());
      }
    }
  }
}

} // namespace utils
} // namespace torch