1 | #pragma once |
2 | #include <iostream> |
3 | |
4 | // note: windows build doesn't find symbols in operator files unless |
5 | // this is a header file |
6 | |
7 | namespace c10 { |
8 | |
9 | inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { |
10 | // eventually this should look almost identical to python arg parser, but |
11 | // it is simpler for now to work directly on this schema |
12 | |
13 | out << schema.name(); |
14 | if (!schema.overload_name().empty()) { |
15 | out << "." << schema.overload_name(); |
16 | } |
17 | out << "(" ; |
18 | |
19 | bool seen_kwarg_only = false; |
20 | for (const auto i : c10::irange(schema.arguments().size())) { |
21 | if (i > 0) out << ", " ; |
22 | if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) { |
23 | out << "*, " ; |
24 | seen_kwarg_only = true; |
25 | } |
26 | out << schema.arguments()[i]; |
27 | } |
28 | |
29 | if(schema.is_vararg()) { |
30 | if(!schema.arguments().empty()) |
31 | out << ", " ; |
32 | out << "..." ; |
33 | } |
34 | |
35 | out << ") -> " ; |
36 | |
37 | const auto& returns = schema.returns(); |
38 | |
39 | /* |
40 | * We should skip parenthesis if we return a single item and it's not varret, |
41 | * or we return nothing but varret. |
42 | * |
43 | * Need special handling for schema |
44 | * aten::items.str(Dict(str, t) self) -> (str,t)[] |
45 | * Even though this schema returns a single item, we need add parenthesis. |
46 | * The is necessary so the printed schema can be parsed by the C++ SchemaParser |
47 | * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly |
48 | * treat the return type as a tuple. An alternative is to enhance the Lexer |
49 | * to lookahead multiple tokens to accurately decide if the return type is |
50 | * a tuple. |
51 | */ |
52 | bool need_paren = !( |
53 | (returns.size() == 1 && !schema.is_varret()) || |
54 | (returns.empty() && schema.is_varret())); |
55 | |
56 | if (returns.size() == 1 && !schema.is_varret()) { |
57 | std::stringstream return_ss; |
58 | return_ss << returns.at(0); |
59 | auto return_str = return_ss.str(); |
60 | |
61 | // enclosing the single return item with parenthesis if the return type |
62 | // starts with a left parenthesis. |
63 | // |
64 | // There are 2 cases |
65 | // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'. |
66 | // without the extra parenthesis, the c++ schem parser can not parse it. |
67 | // 2. something like '-> ((str, str))'. Need extra parenthesis so the return |
68 | // type is a single tuple rather than two strings. |
69 | // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about |
70 | // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15) |
71 | // also covers this case. |
72 | if (!return_str.empty() && return_str.front() == '(') { |
73 | need_paren = true; |
74 | } |
75 | } |
76 | |
77 | if (need_paren) { |
78 | out << "(" ; |
79 | } |
80 | for (const auto i : c10::irange(returns.size())) { |
81 | if (i > 0) { |
82 | out << ", " ; |
83 | } |
84 | out << returns.at(i); |
85 | } |
86 | if (schema.is_varret()) { |
87 | if (!returns.empty()) { |
88 | out << ", " ; |
89 | } |
90 | out << "..." ; |
91 | } |
92 | if (need_paren) { |
93 | out << ")" ; |
94 | } |
95 | return out; |
96 | } |
97 | |
98 | inline size_t findFirstOutArg(const std::vector<Argument>& args) { |
99 | // find the start of out args in the schema |
100 | for (const auto out_start_idx : c10::irange(args.size())) { |
101 | if (args.at(out_start_idx).is_out()) { |
102 | return out_start_idx; |
103 | } |
104 | } |
105 | return args.size(); |
106 | } |
107 | |
108 | inline bool Argument::isBackwardCompatibleWith( |
109 | const Argument& old, |
110 | std::ostream* why_not) const { |
111 | const Argument* lhs = this; |
112 | const Argument* rhs = &old; |
113 | if (!(lhs->name() == rhs->name() |
114 | && lhs->N() == rhs->N() |
115 | && (lhs->alias_info() == rhs->alias_info() |
116 | || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr |
117 | && *lhs->alias_info() == *rhs->alias_info())))) { |
118 | return false; |
119 | } |
120 | if (lhs->kwarg_only() && !rhs->kwarg_only()) { |
121 | return false; |
122 | } |
123 | if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) { |
124 | return false; |
125 | } |
126 | if (rhs->default_value().has_value() && |
127 | lhs->default_value() != rhs->default_value()) { |
128 | return false; |
129 | } |
130 | return true; |
131 | } |
132 | |
133 | inline bool Argument::isForwardCompatibleWith( |
134 | const Argument& old, |
135 | std::ostream* why_not) const { |
136 | const Argument* lhs = this; |
137 | const Argument* rhs = &old; |
138 | if (!(lhs->name() == rhs->name() |
139 | && lhs->N() == rhs->N() |
140 | && (lhs->alias_info() == rhs->alias_info() |
141 | || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr |
142 | && *lhs->alias_info() == *rhs->alias_info())))) { |
143 | return false; |
144 | } |
145 | if (lhs->kwarg_only() && !rhs->kwarg_only()) { |
146 | return false; |
147 | } |
148 | if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) { |
149 | return false; |
150 | } |
151 | if (rhs->default_value().has_value() && |
152 | lhs->default_value() != rhs->default_value()) { |
153 | return false; |
154 | } |
155 | if (lhs->default_value().has_value() && !rhs->default_value().has_value()) { |
156 | return false; |
157 | } |
158 | return true; |
159 | } |
160 | |
161 | inline std::string FunctionSchema::formatTypeMismatchMsg( |
162 | const Argument& expected, |
163 | const std::string& actual_type, |
164 | c10::optional<size_t> position, |
165 | c10::optional<std::string> value) const { |
166 | std::string position_str; |
167 | if (position) { |
168 | position_str = c10::str("Position: " , *position, "\n" ); |
169 | } |
170 | std::string value_str; |
171 | if (value) { |
172 | value_str = c10::str("Value: " , *value, "\n" ); |
173 | } |
174 | return c10::str( |
175 | name(), |
176 | "() " , |
177 | expected.formatTypeMismatchMsg(actual_type), |
178 | position_str, |
179 | value_str, |
180 | "Declaration: " , |
181 | *this); |
182 | } |
183 | |
184 | inline bool FunctionSchema::isBackwardCompatibleWith( |
185 | const FunctionSchema& old, |
186 | std::ostream* why_not) const { |
187 | if (!(name() == old.name() |
188 | && overload_name() == old.overload_name() |
189 | // we are conservative on is_vararg and is_varret, |
190 | // since they are only used by internal operators |
191 | && is_vararg() == old.is_vararg() |
192 | && is_varret() == old.is_varret() |
193 | && returns().size() == old.returns().size() |
194 | && arguments().size() >= old.arguments().size())) { |
195 | return false; |
196 | } |
197 | for (const auto i : c10::irange(returns().size())) { |
198 | // Backwards compatibility requires covariance on argument types |
199 | // (i.e. more generic), and contravariance on return types (i.e. |
200 | // more specific). |
201 | if (!old.returns().at(i).isBackwardCompatibleWith( |
202 | returns().at(i), |
203 | why_not)) { |
204 | return false; |
205 | } |
206 | } |
207 | |
208 | // we want to test both out and default args separately |
209 | size_t old_out_start_idx = findFirstOutArg(old.arguments()); |
210 | size_t new_out_start_idx = findFirstOutArg(arguments()); |
211 | |
212 | // make sure among the default args, they are backward compatible |
213 | for (const auto i : c10::irange(old_out_start_idx)) { |
214 | if (!arguments().at(i).isBackwardCompatibleWith( |
215 | old.arguments().at(i), why_not)) { |
216 | return false; |
217 | } |
218 | } |
219 | |
220 | // Validate that all new arguments provided has a default value |
221 | for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) { |
222 | if (!arguments().at(i).default_value()) { |
223 | if (why_not) { |
224 | *why_not |
225 | << "Function schema not backward compatible since the new argument '" |
226 | << arguments().at(i).name() << "' of type " |
227 | << arguments().at(i).type()->str() |
228 | << " did not provide a default value." ; |
229 | } |
230 | return false; |
231 | } |
232 | } |
233 | |
234 | // now compare the out args |
235 | for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) { |
236 | if (!arguments() |
237 | .at(i - old_out_start_idx + new_out_start_idx) |
238 | .isBackwardCompatibleWith(old.arguments().at(i), why_not)) { |
239 | return false; |
240 | } |
241 | } |
242 | |
243 | return true; |
244 | } |
245 | |
246 | inline bool FunctionSchema::isForwardCompatibleWith( |
247 | const FunctionSchema& old, |
248 | std::ostringstream& why_not) const { |
249 | if (!(name() == old.name() && |
250 | overload_name() == old.overload_name() |
251 | // we are conservative on is_vararg and is_varret, |
252 | // since they are only used by internal operators |
253 | && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() && |
254 | returns().size() == old.returns().size())) { |
255 | return false; |
256 | } |
257 | |
258 | // we want to test both out and default args separately |
259 | size_t old_out_start_idx = findFirstOutArg(old.arguments()); |
260 | size_t new_out_start_idx = findFirstOutArg(arguments()); |
261 | |
262 | if (old.arguments().size() - old_out_start_idx != |
263 | arguments().size() - new_out_start_idx) { |
264 | if (why_not) { |
265 | why_not << "Function schema should have the " |
266 | << "same number of out arguments" ; |
267 | } |
268 | return false; |
269 | } |
270 | |
271 | // make sure among the default args, they are forward compatible |
272 | for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) { |
273 | if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) { |
274 | if (why_not) { |
275 | why_not |
276 | << "'" << arguments().at(i).name() << "'" |
277 | << " is not forward compatible with the older version of the schema" ; |
278 | } |
279 | return false; |
280 | } |
281 | } |
282 | |
283 | // Validate that all new arguments provided has a default value |
284 | for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { |
285 | if (!arguments().at(i).default_value()) { |
286 | if (why_not) { |
287 | why_not |
288 | << "Function schema is not forward compatible since the new argument '" |
289 | << arguments().at(i).name() << "' of type " |
290 | << arguments().at(i).type()->str() |
291 | << " did not provide a default value." ; |
292 | } |
293 | return false; |
294 | } |
295 | |
296 | auto default_val = arguments().at(i).default_value().value(); |
297 | if (default_val.isList() || default_val.isGenericDict()) { |
298 | if (why_not) { |
299 | why_not |
300 | << "Function schema is not forward compatible since the new argument '" |
301 | << arguments().at(i).name() << "' of type " |
302 | << arguments().at(i).type()->str() << " has a container type " |
303 | << "as its default value." ; |
304 | } |
305 | return false; |
306 | } |
307 | } |
308 | |
309 | // now compare the out args |
310 | for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { |
311 | if (!arguments() |
312 | .at(i - old_out_start_idx + new_out_start_idx) |
313 | .isForwardCompatibleWith(old.arguments().at(i))) { |
314 | if (why_not) { |
315 | why_not << "Out argument '" |
316 | << "'" << arguments().at(i).name() |
317 | << " is not FC with the older version of the schema" ; |
318 | } |
319 | return false; |
320 | } |
321 | } |
322 | |
323 | return true; |
324 | } |
325 | |
326 | template<typename T> |
327 | inline void FunctionSchema::checkArg( |
328 | const IValue& value, |
329 | const Argument& argument, |
330 | optional<size_t> pos) const { |
331 | if (value.isTensor() && argument.type() == TensorType::get()) { |
332 | // Fast-path for the common case |
333 | return; |
334 | } |
335 | if (!value.type<T>()->isSubtypeOf(*argument.type())) { |
336 | TORCH_CHECK( |
337 | false, |
338 | formatTypeMismatchMsg( |
339 | argument, value.type<T>()->repr_str(), pos)); |
340 | } |
341 | } |
342 | |
343 | inline std::string FunctionSchema::findErrorInKwargs(const std::vector<std::string>& kwargs) const { |
344 | // First check if any of the kwargs are unknown, i.e. don't match the name of |
345 | // any argument in the schema. |
346 | for (const auto& kwarg : kwargs) { |
347 | if (!std::count_if( |
348 | arguments().begin(), |
349 | arguments().end(), |
350 | [&kwarg](const Argument& argument) { |
351 | return argument.name() == kwarg; |
352 | })) { |
353 | return c10::str( |
354 | "Unknown keyword argument '" , |
355 | kwarg, |
356 | "' for operator '" , |
357 | name(), |
358 | "'. Schema: " , |
359 | *this); |
360 | } |
361 | } |
362 | // If there are unconsumed kwargs but none of them were unknown, the first |
363 | // positional argument present in the kwargs is duplicated. |
364 | for (const auto& argument : arguments()) { |
365 | if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) { |
366 | AT_ASSERT(!argument.default_value()); |
367 | return c10::str( |
368 | "Argument '" , |
369 | argument.name(), |
370 | "' specified both as positional and " , |
371 | "keyword argument. Schema: " , |
372 | *this); |
373 | } |
374 | } |
375 | return "" ; |
376 | } |
377 | |
378 | template <typename T> |
379 | inline void FunctionSchema::checkAndNormalizeInputs( |
380 | std::vector<IValue>& inputs, |
381 | const std::unordered_map<std::string, IValue>& kwargs) const { |
382 | // Do we have more inputs than the schema accepts? |
383 | TORCH_CHECK( |
384 | inputs.size() <= arguments().size(), |
385 | "Expected at most " , |
386 | arguments().size(), |
387 | " argument(s) for operator '" , |
388 | name(), |
389 | "', but received " , |
390 | inputs.size(), |
391 | " argument(s). Declaration: " , |
392 | *this); |
393 | |
394 | size_t consumed_kwargs = 0; |
395 | for (const auto pos : c10::irange(arguments().size())) { |
396 | const auto& argument = arguments()[pos]; |
397 | if (pos < inputs.size()) { |
398 | checkArg<T>(inputs[pos], argument, pos); |
399 | continue; |
400 | } |
401 | auto it = kwargs.find(argument.name()); |
402 | if (it != kwargs.end()) { |
403 | checkArg<T>(it->second, argument, nullopt); |
404 | inputs.push_back(it->second); |
405 | consumed_kwargs++; |
406 | continue; |
407 | } |
408 | if (argument.default_value()) { |
409 | inputs.push_back(*argument.default_value()); |
410 | continue; |
411 | } |
412 | AT_ERROR( |
413 | name(), |
414 | "() is missing value for argument '" , |
415 | argument.name(), |
416 | "'. Declaration: " , |
417 | *this); |
418 | } |
419 | if (consumed_kwargs != kwargs.size()) { |
420 | std::vector<std::string> names; |
421 | names.reserve(kwargs.size()); |
422 | for(const auto& k : kwargs) { |
423 | names.emplace_back(k.first); |
424 | } |
425 | throw std::runtime_error(findErrorInKwargs(names)); |
426 | } |
427 | } |
428 | |
429 | inline FunctionSchema FunctionSchema::cloneWithRemappedTypes( |
430 | const std::function<TypePtr(TypePtr)> type_map) const { |
431 | auto update_args = [&](const std::vector<Argument>& args) { |
432 | std::vector<Argument> new_args; |
433 | new_args.reserve(args.size()); |
434 | for(const Argument& arg : args) { |
435 | new_args.emplace_back(arg.cloneWithType(type_map(arg.type()))); |
436 | } |
437 | return new_args; |
438 | }; |
439 | return FunctionSchema( |
440 | name(), |
441 | overload_name(), |
442 | update_args(arguments()), |
443 | update_args(returns()), |
444 | is_vararg(), |
445 | is_varret()); |
446 | } |
447 | |
448 | // covariant subtyping of list of Arguments |
449 | inline bool isSubtypeOfList( |
450 | ArrayRef<Argument> child, |
451 | ArrayRef<Argument> parent, |
452 | std::ostream* why_not) { |
453 | if (child.size() != parent.size()) { |
454 | return false; |
455 | } |
456 | for (const auto i : c10::irange(child.size())) { |
457 | const Argument& c = child[i]; |
458 | const Argument& p = parent[i]; |
459 | if (c.name() != p.name()) { |
460 | return false; |
461 | } |
462 | if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) { |
463 | return false; |
464 | } |
465 | } |
466 | return true; |
467 | } |
468 | |
469 | inline bool FunctionSchema::isSubtypeOf( |
470 | const FunctionSchema& rhs, |
471 | bool as_method, |
472 | std::ostream* why_not) const { |
473 | size_t start = as_method ? 1 : 0; |
474 | // functions are contravariant in arguments but covariant in returns |
475 | return isSubtypeOfList( |
476 | ArrayRef<Argument>(rhs.arguments()).slice(start), |
477 | ArrayRef<Argument>(arguments()).slice(start), |
478 | why_not) && |
479 | isSubtypeOfList(returns(), rhs.returns(), why_not); |
480 | } |
481 | |
482 | } // namespace c10 |
483 | |