1 | #pragma once |
2 | |
3 | #include <c10/util/StringUtil.h> |
4 | #include <c10/util/string_view.h> |
5 | #include <c10/util/irange.h> |
6 | #include <ATen/core/jit_type.h> |
7 | #include <ATen/core/symbol.h> |
8 | #include <ATen/core/ivalue.h> |
9 | #include <ATen/core/alias_info.h> |
10 | #include <ATen/core/operator_name.h> |
11 | #include <ATen/core/dispatch/OperatorOptions.h> |
12 | #include <unordered_map> |
13 | |
14 | namespace c10 { |
15 | |
16 | // schema as used in the compiler for resolving function calls and reporting |
17 | // errors. These objects should be constructed from C10 schema once those |
18 | // are available. |
19 | |
20 | struct Argument; |
21 | struct FunctionSchema; |
22 | |
23 | using AliasTypeSet = std::vector<TypePtr>; |
24 | |
25 | bool operator==(const Argument& lhs, const Argument& rhs); |
26 | |
27 | struct Argument { |
28 | Argument( |
29 | std::string name = "" , |
30 | TypePtr type = nullptr, |
31 | c10::optional<int32_t> N = c10::nullopt, |
32 | c10::optional<IValue> default_value = c10::nullopt, |
33 | bool kwarg_only = false, |
34 | c10::optional<AliasInfo> alias_info = c10::nullopt) |
35 | : Argument(name, type, type, N, default_value, kwarg_only, alias_info) {} |
36 | |
37 | Argument( |
38 | std::string name, |
39 | TypePtr fake_type, |
40 | TypePtr real_type, |
41 | c10::optional<int32_t> N = c10::nullopt, |
42 | c10::optional<IValue> default_value = c10::nullopt, |
43 | bool kwarg_only = false, |
44 | c10::optional<AliasInfo> alias_info = c10::nullopt) |
45 | : name_(std::move(name)), |
46 | type_(fake_type ? std::move(fake_type) : TensorType::get()), |
47 | real_type_(real_type ? std::move(real_type) : type_), |
48 | N_(std::move(N)), |
49 | default_value_(std::move(default_value)), |
50 | alias_info_(alias_info ? std::make_unique<AliasInfo>(std::move(*alias_info)) : nullptr), |
51 | kwarg_only_(kwarg_only) { |
52 | // this is an softly-enforced invariant for out arguments. |
53 | bool is_alias = alias_info_ != nullptr && alias_info_->isWrite(); |
54 | is_out_ = kwarg_only_ && is_alias; |
55 | } |
56 | |
57 | Argument(Argument&& rhs) noexcept = default; |
58 | |
59 | Argument(const Argument& rhs) |
60 | : name_(rhs.name_), |
61 | type_(rhs.type_), |
62 | real_type_(rhs.real_type_), |
63 | N_(rhs.N_), |
64 | default_value_(rhs.default_value_), |
65 | alias_info_(rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr), |
66 | kwarg_only_(rhs.kwarg_only_), |
67 | is_out_(rhs.is_out_) {} |
68 | |
69 | Argument& operator=(Argument&& rhs) = default; |
70 | |
71 | Argument& operator=(const Argument& rhs) { |
72 | if (this != &rhs) { |
73 | name_ = rhs.name_; |
74 | type_ = rhs.type_; |
75 | real_type_ = rhs.real_type_; |
76 | N_ = rhs.N_; |
77 | default_value_ = rhs.default_value_; |
78 | alias_info_ = rhs.alias_info_ ? std::make_unique<AliasInfo>(*rhs.alias_info_) : nullptr; |
79 | kwarg_only_ = rhs.kwarg_only_; |
80 | is_out_ = rhs.is_out_; |
81 | } |
82 | return *this; |
83 | } |
84 | |
85 | const std::string& name() const { |
86 | return name_; |
87 | } |
88 | const TypePtr& type() const { |
89 | return type_; |
90 | } |
91 | // if type() is non-null, this is guaranteed to be non-null (if no real |
92 | // type was provided, this takes on type()'s value) |
93 | const TypePtr& real_type() const { |
94 | return real_type_; |
95 | } |
96 | c10::optional<int32_t> N() const { |
97 | return N_; |
98 | } |
99 | const c10::optional<IValue>& default_value() const { |
100 | return default_value_; |
101 | } |
102 | bool kwarg_only() const { |
103 | return kwarg_only_; |
104 | } |
105 | |
106 | bool is_out() const { |
107 | return is_out_; |
108 | } |
109 | |
110 | C10_NODISCARD const AliasInfo* alias_info() const { |
111 | return alias_info_.get(); |
112 | } |
113 | |
114 | bool is_inferred_type() const { |
115 | bool is_inferred_type = false; |
116 | TORCH_INTERNAL_ASSERT(type_); |
117 | if (auto pt = type_->cast<TensorType>()) { |
118 | if (pt->isInferredType()) { |
119 | is_inferred_type = true; |
120 | } |
121 | } |
122 | return is_inferred_type; |
123 | } |
124 | |
125 | std::string formatTypeMismatchMsg(const std::string& actual_type) const { |
126 | std::string inferred_type_hint; |
127 | if (is_inferred_type()) { |
128 | inferred_type_hint = c10::str( |
129 | "Inferred '" , |
130 | name(), |
131 | "' to be of type 'Tensor' " , |
132 | "because it was not annotated with an explicit type.\n" ); |
133 | } |
134 | return c10::str( |
135 | "Expected a value of type '" , |
136 | type()->repr_str(), |
137 | "' for argument '" , |
138 | name(), |
139 | "' but instead found type '" , |
140 | actual_type, |
141 | "'.\n" , |
142 | inferred_type_hint); |
143 | } |
144 | |
145 | Argument cloneWithType(TypePtr new_type) const { |
146 | return Argument( |
147 | name_, |
148 | std::move(new_type), |
149 | N_, |
150 | default_value_, |
151 | kwarg_only_, |
152 | alias_info_ ? c10::optional<AliasInfo>(*alias_info_) : c10::nullopt); |
153 | } |
154 | |
155 | // this function checks whether this Argument is backward compatible with |
156 | // the old one. we consider the following cases are backward compatible: |
157 | // 1) two arguments are equal |
158 | // 2) this arg's type should be subtype of old |
159 | // 3) this arg must provide the same default value if old arg has one, |
160 | bool isBackwardCompatibleWith( |
161 | const Argument& old, |
162 | std::ostream* why_not=nullptr) const; |
163 | |
164 | // this function checks whether this Argument is forward compatible with |
165 | // the old one. we consider the following cases are forward compatible: |
166 | // 1) two arguments are equal |
167 | // 2) this arg's type should be subtype of old |
168 | // 3) this arg must provide the same default value if old arg has one, |
169 | bool isForwardCompatibleWith( |
170 | const Argument& old, |
171 | std::ostream* why_not = nullptr) const; |
172 | |
173 | private: |
174 | std::string name_; |
175 | TypePtr type_; |
176 | TypePtr real_type_; // this is ScalarType, not int, e.g. |
177 | // for list types, an optional statically known length for the list |
178 | // e.g. for int[3]: type = ListType::ofInts(), N = 3 |
179 | // If present, this will allow scalars to be broadcast to this length to |
180 | // become a list. |
181 | c10::optional<int32_t> N_; |
182 | |
183 | c10::optional<IValue> default_value_; |
184 | // AliasInfo is huge, so let's only allocate memory for it if |
185 | // necessary (which it isn't during schema parsing on startup, to |
186 | // give a pertinent example). |
187 | std::unique_ptr<AliasInfo> alias_info_; |
188 | // is this only specifiable as a keyword argument? |
189 | bool kwarg_only_; |
190 | // marks if the argument is out variant of the schema |
191 | bool is_out_; |
192 | }; |
193 | |
194 | inline bool operator==(const Argument& lhs, const Argument& rhs) { |
195 | return lhs.name() == rhs.name() |
196 | && *lhs.type() == *rhs.type() |
197 | && lhs.N() == rhs.N() |
198 | && lhs.default_value() == rhs.default_value() |
199 | && lhs.kwarg_only() == rhs.kwarg_only() |
200 | && (lhs.alias_info() == rhs.alias_info() |
201 | || (lhs.alias_info() != nullptr && rhs.alias_info() != nullptr |
202 | && *lhs.alias_info() == *rhs.alias_info())); |
203 | } |
204 | |
205 | inline bool operator!=(const Argument& lhs, const Argument& rhs) { |
206 | return !(lhs == rhs); |
207 | } |
208 | |
209 | enum struct TORCH_API SchemaArgType { input, output }; |
210 | |
211 | /** |
212 | * struct SchemaArgument |
213 | * |
214 | * Structure used to represent arguments or returns for a schema. |
215 | */ |
216 | struct TORCH_API SchemaArgument { |
217 | SchemaArgType type; |
218 | size_t index; |
219 | SchemaArgument(SchemaArgType tpe, size_t idx) : type(tpe), index(idx) {} |
220 | bool operator==(const SchemaArgument& rhs) const { |
221 | return type == rhs.type && index == rhs.index; |
222 | } |
223 | }; |
224 | |
225 | bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs); |
226 | |
227 | struct TORCH_API FunctionSchema { |
228 | FunctionSchema( |
229 | std::string name, |
230 | std::string overload_name, |
231 | std::vector<Argument> arguments, |
232 | std::vector<Argument> returns, |
233 | bool is_vararg = false, |
234 | bool is_varret = false) |
235 | : name_({std::move(name), std::move(overload_name)}), |
236 | arguments_(std::move(arguments)), |
237 | returns_(std::move(returns)), |
238 | is_vararg_(is_vararg), |
239 | is_varret_(is_varret) { |
240 | checkSchema(); |
241 | } |
242 | |
243 | FunctionSchema( |
244 | Symbol name, |
245 | std::string overload_name, |
246 | std::vector<Argument> arguments, |
247 | std::vector<Argument> returns, |
248 | bool is_vararg = false, |
249 | bool is_varret = false) |
250 | : FunctionSchema( |
251 | name.toQualString(), |
252 | std::move(overload_name), |
253 | std::move(arguments), |
254 | std::move(returns), |
255 | is_vararg, |
256 | is_varret) { |
257 | checkSchema(); |
258 | } |
259 | |
260 | // Checks whether this schema is backward compatible with the old one. |
261 | // The following conditions must be true: |
262 | // [Function structure] The new schema's name, overload-name, varargs, and |
263 | // return arity are the same. |
264 | // [Output Narrowing] The new schema's output type must be the same class |
265 | // or inherit from the old schema's output type. |
266 | // [Argument count] The new schema must have at least as many arguments as |
267 | // the old schema (considering the list of positional and kwargs). |
268 | // [Arg Compatibility] Every argument in the old schema has a corresponding |
269 | // argument in the new schema that: |
270 | // * is at the same position. |
271 | // * has the same name. |
272 | // * is either positional, or kwarg and the old argument was kwarg. |
273 | // * has the same type, or the old argument's type inherits from the |
274 | // new argument's type. |
275 | // [Default Values] Every new argument must have a default value. |
276 | // E.g. |
277 | // OK f_new(a, b, c=1) => f_old(a, b) |
278 | // NOK f_new(a, c=1, *, b) => f_old(a, *, b) |
279 | // OK f_new(a, b, *, c) => f_old(a, *, b, c) |
280 | // NOK f_new(a, *, b, c) -> f_old(a, b, *, c) |
281 | // NOK f_new(a, *, c, b) => f_old(a, *, b, c) |
282 | // OK f_new(a, *, b, c, d=1) => f_old(a, *, b, c) |
283 | bool isBackwardCompatibleWith( |
284 | const FunctionSchema& old, |
285 | std::ostream* why_not = nullptr) const; |
286 | |
287 | // Checks whether this schema is forward compatible with the old one. |
288 | // The following conditions must be true: |
289 | // [Function structure] The new schema's name, overload-name, varargs, and |
290 | // return arity are the same. |
291 | // [Output Narrowing] The new schema's output type must be the same class |
292 | // or inherit from the old schema's output type. |
293 | // [Arg Compatibility] Every argument in the old schema has a corresponding |
294 | // argument in the new schema that: |
295 | // * is at the same position. |
296 | // * has the same name. |
297 | // * is either positional, or kwarg and the old argument was kwarg. |
298 | // * has the same type, or the old argument's type inherits from the |
299 | // new argument's type. |
300 | // [Default Values] Every new argument must have a default value. |
301 | // Each default value type should NOT be a container type. |
302 | // [Positioning] All defaults arguments MUST go after either old |
303 | // default arguments or the end of positional arguments |
304 | // and right BEFORE all out arguments |
305 | bool isForwardCompatibleWith( |
306 | const FunctionSchema& old, |
307 | std::ostringstream& why_not) const; |
308 | |
309 | private: |
310 | OperatorName name_; |
311 | std::vector<Argument> arguments_; |
312 | std::vector<Argument> returns_; |
313 | // if true then this schema takes an arbitrary number of additional arguments |
314 | // after the argument specified in arguments |
315 | // currently this is used primarily to represent 'primitive' operators whose |
316 | // arguments are not checked by schema |
317 | bool is_vararg_; |
318 | bool is_varret_; |
319 | |
320 | // if no alias information is directly specified, what kind of "default" |
321 | // alias information should we infer? |
322 | // NB: due to alias analysis kind merging, this may be nullopt. Eventually |
323 | // this should always be set no matter what |
324 | c10::optional<AliasAnalysisKind> alias_kind_; |
325 | |
326 | template <typename T> |
327 | void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const; |
328 | |
329 | void checkSchema() const { |
330 | bool seen_default_arg = false; |
331 | for (const auto& arg : arguments()) { |
332 | if (arg.default_value()) { |
333 | seen_default_arg = true; |
334 | } else { |
335 | // we have historically serialized broadcasting lists wo/default values, |
336 | // so to not break BC allow lists here |
337 | if (arg.type()->kind() == ListType::Kind) { |
338 | continue; |
339 | } |
340 | TORCH_INTERNAL_ASSERT( |
341 | !seen_default_arg || arg.kwarg_only(), |
342 | "Non-default positional argument follows default argument. Parameter " , |
343 | arg.name(), |
344 | " in " , |
345 | *this); |
346 | } |
347 | } |
348 | } |
349 | |
350 | public: |
351 | |
352 | void dump() const; |
353 | |
354 | const OperatorName& operator_name() const { |
355 | return name_; |
356 | } |
357 | const std::string& name() const { |
358 | return name_.name; |
359 | } |
360 | const std::string& overload_name() const { |
361 | return name_.overload_name; |
362 | } |
363 | const std::vector<Argument>& arguments() const { |
364 | return arguments_; |
365 | } |
366 | const std::vector<Argument>& returns() const { |
367 | return returns_; |
368 | } |
369 | bool is_vararg() const { |
370 | return is_vararg_; |
371 | } |
372 | bool is_varret() const { |
373 | return is_varret_; |
374 | } |
375 | bool is_aliasing(const c10::SchemaArgument &argument) const { |
376 | TORCH_INTERNAL_ASSERT( |
377 | argument.index < getCorrectList(argument.type).size(), |
378 | "Invalid index for schema." ); |
379 | const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); |
380 | return aliasInfo; |
381 | } |
382 | bool is_mutable() const { |
383 | return std::any_of( |
384 | arguments_.cbegin(), arguments_.cend(), [](const Argument& arg) { |
385 | const AliasInfo* aliasInfo = arg.alias_info(); |
386 | return aliasInfo && aliasInfo->isWrite(); |
387 | }); |
388 | } |
389 | bool is_mutable(const c10::SchemaArgument &argument) const { |
390 | TORCH_INTERNAL_ASSERT( |
391 | argument.index < getCorrectList(argument.type).size(), |
392 | "Invalid index for schema." ); |
393 | const AliasInfo* aliasInfo = getCorrectList(argument.type)[argument.index].alias_info(); |
394 | return aliasInfo && aliasInfo->isWrite(); |
395 | } |
396 | bool is_mutable(c10::string_view name) const { |
397 | c10::optional<int> index = argumentIndexWithName(name); |
398 | TORCH_INTERNAL_ASSERT( |
399 | index != c10::nullopt, "Schema has no argument named " , name); |
400 | |
401 | return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)}); |
402 | } |
403 | |
404 | // Returns whether lhs and rhs may alias directly. |
405 | // This does not account for cases where lhs or rhs are a container that |
406 | // may contain elements that alias the other argument. |
407 | // FunctionSchema::may_contain_alias will include that functionality. |
408 | bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const; |
409 | |
410 | // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container |
411 | // that may contain elements that alias the other argument. |
412 | // bidirectional = false only returns whether lhs may contain an alias of rhs |
413 | // while bidirectional = true returns both directions. |
414 | bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const; |
415 | |
416 | // Returns whether the two AliasTypeSets contain any similarities |
417 | // ie: whether the two type sets can alias. |
418 | bool canAliasTypeSetsAlias(const c10::optional<AliasTypeSet> &lhs, const c10::optional<AliasTypeSet> &rhs) const; |
419 | |
420 | // Recursively Finds all contained types within the AliasTypeSet. |
421 | c10::optional<AliasTypeSet> getAliasTypeSetContainedTypes(const c10::optional<AliasTypeSet> &aliasTypeSet) const; |
422 | |
423 | // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp. |
424 | // Used to map types to a type such that all types that can alias will be mapped to the same type. |
425 | // For example, calling this method on 'Optional[List[int]]' is the same as calling this method |
426 | // on 'List[int]'. |
427 | c10::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) const; |
428 | |
429 | // Returns either arguments() or returns() depending on the SchemaArgType |
430 | // output => returns(), input => arguments() |
431 | const std::vector<Argument>& getCorrectList(SchemaArgType type) const; |
432 | |
433 | c10::optional<int> argumentIndexWithName(c10::string_view name) const { |
434 | for (const auto i : c10::irange(arguments().size())) { |
435 | if(name == arguments()[i].name()) |
436 | return i; |
437 | } |
438 | return c10::nullopt; |
439 | } |
440 | FunctionSchema cloneWithName(std::string name, std::string overload_name) const { |
441 | return FunctionSchema( |
442 | std::move(name), |
443 | std::move(overload_name), |
444 | arguments(), |
445 | returns(), |
446 | is_vararg(), |
447 | is_varret() |
448 | ); |
449 | } |
450 | FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const { |
451 | return FunctionSchema( |
452 | name(), |
453 | overload_name(), |
454 | std::move(new_arguments), |
455 | returns(), |
456 | is_vararg(), |
457 | is_varret()); |
458 | } |
459 | FunctionSchema cloneWithReturns(std::vector<Argument> new_returns) const { |
460 | return FunctionSchema( |
461 | name(), |
462 | overload_name(), |
463 | arguments(), |
464 | std::move(new_returns), |
465 | is_vararg(), |
466 | is_varret()); |
467 | } |
468 | |
469 | std::string formatTypeMismatchMsg( |
470 | const Argument& expected, |
471 | const std::string& actual_type, |
472 | c10::optional<size_t> position = c10::nullopt, |
473 | c10::optional<std::string> value = c10::nullopt) const; |
474 | |
475 | FunctionSchema cloneWithRemappedTypes( |
476 | const std::function<TypePtr(TypePtr)> type_map) const; |
477 | |
478 | FunctionSchema cloneWithRealTypes(bool with_symint=true) const; |
479 | |
480 | // Check that inputs have the correct types and appends any missing default |
481 | // values. |
482 | template <typename T = c10::PlatformType> |
483 | void checkAndNormalizeInputs( |
484 | std::vector<IValue>& inputs, |
485 | const std::unordered_map<std::string, IValue>& kwargs = |
486 | std::unordered_map<std::string, IValue>{}) const; |
487 | |
488 | std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const; |
489 | |
490 | bool hasAnyAliasInfo() const { |
491 | for (const auto& arg : arguments_) { |
492 | if (arg.alias_info() != nullptr) { |
493 | return true; |
494 | } |
495 | } |
496 | for (const auto& ret : returns_) { |
497 | if (ret.alias_info() != nullptr) { |
498 | return true; |
499 | } |
500 | } |
501 | return false; |
502 | } |
503 | |
504 | |
505 | // TODO remove the mutation here |
506 | bool isDefaultAliasAnalysisKind() const { |
507 | return !alias_kind_; |
508 | } |
509 | AliasAnalysisKind aliasAnalysis() const { |
510 | return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE); |
511 | } |
512 | void setAliasAnalysis(AliasAnalysisKind v) { |
513 | alias_kind_ = v; |
514 | } |
515 | |
516 | c10::optional<c10::string_view> getNamespace() const { |
517 | return name_.getNamespace(); |
518 | } |
519 | |
520 | // Returns true if we successfully set the namespace (as there |
521 | // was none set, and false otherwise) |
522 | bool setNamespaceIfNotSet(const char* ns) { |
523 | return name_.setNamespaceIfNotSet(ns); |
524 | } |
525 | |
526 | // can a function with this schema be substituted for a function of rhs's |
527 | // schema and have the program typecheck? |
528 | // as_method - if true, treat this schema as a method and ignore |
529 | // the first argument, which will be the object in both cases |
530 | bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const; |
531 | }; |
532 | |
533 | inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) { |
534 | return lhs.name() == rhs.name() |
535 | && lhs.overload_name() == rhs.overload_name() |
536 | && lhs.arguments() == rhs.arguments() |
537 | && lhs.returns() == rhs.returns() |
538 | && lhs.is_vararg() == rhs.is_vararg() |
539 | && lhs.is_varret() == rhs.is_varret(); |
540 | } |
541 | |
542 | inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) { |
543 | return !(lhs == rhs); |
544 | } |
545 | |
546 | // print out Argument, which is compatible with FunctionSchema parser |
547 | // full format: Type(alias)? name=default_value |
548 | inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { |
549 | |
550 | // for adjusting the ? position. |
551 | // in schema, we have Tensor?(a!) input, and t(a!)?. |
552 | // however, t?(a!) doesn't work with schema parser. |
553 | // so we always use Type(alias)? format |
554 | // real_type versus fake_type: in order to be compatible with FunctionSchema |
555 | // parser, printing an argument with either MemoryFormat or Layout type should |
556 | // give us the original schema string, hence printing out real_type. |
557 | auto type = arg.real_type(); |
558 | bool is_opt = type->kind() == OptionalType::Kind; |
559 | auto unopt_type = is_opt ? type->castRaw<OptionalType>()->getElementType() : type; |
560 | |
561 | if (unopt_type->kind() == ListType::Kind) { |
562 | // sized lists get size N from arg, not type |
563 | auto list = unopt_type->cast<c10::ListType>(); |
564 | out << list->getElementType()->str(); |
565 | if (arg.alias_info() && !arg.alias_info()->containedTypes().empty()){ |
566 | out << arg.alias_info()->containedTypes()[0]; |
567 | } |
568 | std::string N = "" ; |
569 | if (arg.N()) { |
570 | N = std::to_string(*arg.N()); |
571 | } |
572 | out << "[" << N << "]" ; |
573 | } else { |
574 | out << unopt_type->str(); |
575 | } |
576 | |
577 | // print alias info if it has beforeSets. |
578 | if (arg.alias_info() && !arg.alias_info()->beforeSets().empty()) { |
579 | out << *arg.alias_info(); |
580 | } |
581 | |
582 | if (is_opt) { |
583 | out << "?" ; |
584 | } |
585 | |
586 | if (!arg.name().empty()) { |
587 | out << " " << arg.name(); |
588 | } |
589 | |
590 | if (arg.default_value()) { |
591 | out << "=" ; |
592 | if ((type->kind() == c10::TypeKind::StringType || |
593 | unopt_type->kind() == c10::TypeKind::StringType) && |
594 | arg.default_value().value().isString()) { |
595 | printQuotedString(out, arg.default_value().value().toStringRef()); |
596 | } else if (type->kind() == TypeKind::ListType && type->castRaw<ListType>()->getElementType()->kind() == c10::TypeKind::IntType) { |
597 | // We want to faithfully replicate JIT schema. |
598 | // in native_functions.yaml defaults for int arrays with a single value always look like |
599 | // int[2] stride=1 |
600 | // instead of |
601 | // int[2] stride=[1, 1] |
602 | auto default_val = arg.default_value().value().toIntList(); |
603 | if (default_val.size() > 1) { |
604 | auto all_defaults_the_same = true; |
605 | for (const auto i : c10::irange(1, default_val.size())) { |
606 | if (default_val[0] != default_val[i]) all_defaults_the_same = false; |
607 | } |
608 | if (all_defaults_the_same) { |
609 | out << default_val[0]; |
610 | } else { |
611 | out << arg.default_value().value(); |
612 | } |
613 | } else { |
614 | out << arg.default_value().value(); |
615 | } |
616 | } else { |
617 | out << arg.default_value().value(); |
618 | } |
619 | } |
620 | |
621 | return out; |
622 | } |
623 | |
624 | inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema); |
625 | |
626 | inline std::string toString(const FunctionSchema& schema) { |
627 | std::ostringstream str; |
628 | str << schema; |
629 | return str.str(); |
630 | } |
631 | |
632 | } // namespace c10 |
633 | |
634 | namespace std { |
635 | template<> |
636 | struct hash<c10::SchemaArgument> { |
637 | size_t operator()(const c10::SchemaArgument& arg) const |
638 | { |
639 | return c10::hash_combine(std::hash<size_t>()(arg.index), std::hash<size_t>()(static_cast<std::size_t>(arg.type))); |
640 | } |
641 | }; |
642 | template<> |
643 | struct hash<c10::Argument> { |
644 | size_t operator()(const c10::Argument& arg) const |
645 | { |
646 | auto hash = std::hash<std::string>{}(arg.name()); |
647 | auto type_hash = std::hash<c10::TypePtr>{}(arg.type()); |
648 | auto kwarg_only_hash = std::hash<bool>{}(arg.kwarg_only()); |
649 | hash = c10::hash_combine(hash, type_hash); |
650 | hash = c10::hash_combine(hash, kwarg_only_hash); |
651 | // hashing optional fields if they exist |
652 | if (arg.default_value()) { |
653 | auto default_value_hash = c10::hash<c10::IValue>{}(arg.default_value().value()); |
654 | hash = c10::hash_combine(hash, default_value_hash); |
655 | } |
656 | if (arg.N()) { |
657 | auto N_hash = std::hash<int64_t>{}(*arg.N()); |
658 | hash = c10::hash_combine(hash, N_hash); |
659 | } |
660 | if (arg.alias_info()) { |
661 | auto alias_info_hash = std::hash<c10::AliasInfo>{}(*arg.alias_info()); |
662 | hash = c10::hash_combine(hash, alias_info_hash); |
663 | } |
664 | return hash; |
665 | } |
666 | }; |
667 | template<> |
668 | struct hash<c10::FunctionSchema> { |
669 | size_t operator()(const c10::FunctionSchema& schema) const |
670 | { |
671 | auto hash = std::hash<c10::OperatorName>{}(schema.operator_name()); |
672 | auto args_hash = c10::hash<std::vector<c10::Argument>>{}(schema.arguments()); |
673 | auto returns_hash = c10::hash<std::vector<c10::Argument>>{}(schema.returns()); |
674 | auto is_vararg_hash = std::hash<bool>{}(schema.is_vararg()); |
675 | auto is_varret_hash = std::hash<bool>{}(schema.is_varret()); |
676 | hash = c10::hash_combine(hash, args_hash); |
677 | hash = c10::hash_combine(hash, returns_hash); |
678 | hash = c10::hash_combine(hash, is_vararg_hash); |
679 | hash = c10::hash_combine(hash, is_varret_hash); |
680 | return hash; |
681 | } |
682 | }; |
683 | } // namespace std |
684 | |
685 | |
686 | #include <ATen/core/function_schema_inl.h> // IWYU pragma: keep |
687 | |