1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ |
20 | #define TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ |
21 | |
22 | #include <tvm/tir/schedule/instruction.h> |
23 | #include <tvm/tir/schedule/schedule.h> |
24 | |
25 | #include <algorithm> |
26 | #include <sstream> |
27 | #include <string> |
28 | #include <utility> |
29 | #include <vector> |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
/*!
 * \brief Register an InstructionKind using a trait class
 * \param InstructionKindTraits A traits class of an InstructionKind
 *
 * The macro expands to one `TVM_REGISTER_INST_KIND(...)` expression followed by a chain of
 * `set_*` calls. It reads exactly these members of `InstructionKindTraits`: `kName`, `kIsPure`,
 * `ApplyToSchedule`, `AttrsAsJSON`, `AttrsFromJSON` and `AsPython`. The expansion does not end
 * with a semicolon, so the caller supplies it.
 *
 * Example:
 *
 * \code
 *
 * struct SomeInstructionKindTraits {
 *   static constexpr const char* kName = "name-of-the-instruction";
 *   static constexpr bool kIsPure = false;
 *
 *   // Convertible to `InstructionKindNode::FInstructionApply`
 *   static Array<ObjectRef> ApplyToSchedule(
 *      const tir::Schedule& sch,
 *      const Array<ObjectRef>& inputs,
 *      const Array<ObjectRef>& attrs,
 *      const Optional<ObjectRef>& decision);
 *
 *   // Convertible to `InstructionKindNode::FInstructionAsPython`
 *   static String AsPython(
 *      const Array<String>& inputs,
 *      const Array<ObjectRef>& attrs,
 *      const Optional<ObjectRef>& decision,
 *      const Array<String>& outputs);
 *
 *   // Convertible to `InstructionKindNode::FInstructionAttrsAsJSON`
 *   static ObjectRef AttrsAsJSON(
 *      const Array<ObjectRef>& attrs);
 *
 *   // Convertible to `InstructionKindNode::FInstructionAttrsFromJSON`
 *   static Array<ObjectRef> AttrsFromJSON(
 *      const ObjectRef& attrs_record);
 * };
 *
 * TVM_REGISTER_INST_KIND_TRAITS(SomeInstructionKindTraits);
 *
 * \endcode
 */
#define TVM_REGISTER_INST_KIND_TRAITS(InstructionKindTraits)      \
  TVM_REGISTER_INST_KIND(InstructionKindTraits::kName)            \
      .set_is_pure(InstructionKindTraits::kIsPure)                \
      .set_apply_to_schedule(InstructionKindTraits::ApplyToSchedule) \
      .set_attrs_as_json(InstructionKindTraits::AttrsAsJSON)      \
      .set_attrs_from_json(InstructionKindTraits::AttrsFromJSON)  \
      .set_as_python(InstructionKindTraits::AsPython)
80 | |
81 | /*! |
 * \brief A helper to conveniently register an InstructionKind. When inherited in the curiously
 * recurring template pattern (CRTP), the derived class `TTraits` only needs to define two functions
 * on the unpacked inputs, and the helper handles unpacking and downcasting. See the example for
 * more details.
86 | * |
87 | * \tparam TTraits The derived class |
88 | * |
89 | * Example: |
90 | * |
91 | * \code |
92 | * |
93 | * struct SamplePerfectTileTraits : public UnpackedInstTraits<SamplePerfectTileTraits> { |
94 | * // The name of this kind of instruction |
95 | * static constexpr const char* kName = "SamplePerfectTile"; |
 *   // A boolean indicating if the instruction is pure, i.e. it leaves the schedule state unchanged
97 | * static constexpr bool kIsPure = true; |
98 | * // The number of inputs in this kind of instruction |
99 | * static constexpr size_t kNumInputs = 1; |
100 | * // The number of attributes in this kind of instruction |
101 | * static constexpr size_t kNumAttrs = 2; |
102 | * // The number of decisions in this kind of instruction (only 0 or 1 is allowed) |
103 | * static constexpr size_t kNumDecisions = 1; |
104 | * |
105 | * // Calling convention: |
106 | * // - All the arguments must be ObjectRef |
107 | * // - The 1st argument is Schedule |
108 | * // - The next `kNumInputs` arguments are input random variables |
109 | * // - The next `kNumAttrs` arguments are attributes |
110 | * // - The next argument is decision, if `kNumDecisions == 1` |
111 | * static Array<Var> UnpackedApplyToSchedule( |
112 | * Schedule sch, |
113 | * LoopRV loop_rv, |
114 | * Integer n, |
115 | * Integer max_innermost_factor, |
116 | * Optional<Array<Integer>> decision) { |
117 | * return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); |
118 | * } |
119 | * |
120 | * // Calling convention: |
121 | * // - All the arguments must be ObjectRef |
122 | * // - The 1st argument is an array containing names of output random variables |
123 | * // - The next `kNumInputs` arguments are names of input random variables |
124 | * // - The next `kNumAttrs` arguments are attributes |
125 | * // - The next argument is decision, if `kNumDecisions == 1` |
126 | * static String UnpackedAsPython( |
127 | * Array<String> outputs, |
128 | * String loop_rv, |
129 | * Integer n, |
130 | * Integer max_innermost_factor, |
131 | * Optional<Array<Integer>> decision) { |
132 | * PythonAPICall py("sample_perfect_tile"); |
133 | * py.Input("loop", loop_rv); |
134 | * py.Input("n", n->value); |
135 | * py.Input("max_innermost_factor", max_innermost_factor->value); |
136 | * py.Decision(decision); |
137 | * py.OutputList(outputs); |
138 | * return py.Str(); |
139 | * } |
140 | * |
141 | * template <typename> |
142 | * friend struct UnpackedInstTraits; |
143 | * }; |
144 | * |
 * TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits);
146 | * \endcode |
147 | */ |
148 | template <class TTraits> |
149 | struct UnpackedInstTraits { |
150 | /*! |
151 | * \brief Unpack the arguments in the calling convention, and feed them into |
152 | * `TTraits::UnpackedApplyToSchedule` |
153 | * \sa InstructionKindNode::f_apply_to_schedule |
154 | */ |
155 | static Array<ObjectRef> ApplyToSchedule(const Schedule& sch, const Array<ObjectRef>& inputs, |
156 | const Array<ObjectRef>& attrs, |
157 | const Optional<ObjectRef>& decision); |
158 | |
159 | /*! |
160 | * \brief Unpack the arguments in the calling convention, and feed them into |
161 | * `TTraits::UnpackedAsPython` |
162 | * \sa InstructionKindNode::f_as_python |
163 | */ |
164 | static String AsPython(const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, |
165 | const Optional<ObjectRef>& decision, const Array<String>& outputs); |
166 | |
167 | /*! \brief No customized serializer by default */ |
168 | static constexpr std::nullptr_t AttrsAsJSON = nullptr; |
169 | |
170 | /*! \brief No customized deserializer by default */ |
171 | static constexpr std::nullptr_t AttrsFromJSON = nullptr; |
172 | |
173 | protected: |
174 | template <size_t index_offset> |
175 | static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, |
176 | const Array<ObjectRef>& inputs); |
177 | template <size_t index_offset> |
178 | static TVM_ALWAYS_INLINE void _SetAttrs(const runtime::TVMArgsSetter& setter, |
179 | const Array<ObjectRef>& attrs); |
180 | template <size_t index_offset> |
181 | static TVM_ALWAYS_INLINE void _SetDecision(const runtime::TVMArgsSetter& setter, |
182 | const Optional<ObjectRef>& decision); |
183 | static TVM_ALWAYS_INLINE Array<ObjectRef> _ConvertOutputs(const TVMRetValue& rv); |
184 | }; |
185 | |
186 | /*! |
187 | * \brief A helper class that constructs schedule API call in python syntax, |
188 | * which helps convert an Inst to a python statement. |
189 | * \sa InstructionKindNode::f_as_python |
190 | */ |
191 | class PythonAPICall { |
192 | public: |
193 | /*! |
194 | * \brief Constructor |
195 | * \param method_name The name of the schedule API to be called |
196 | */ |
197 | explicit PythonAPICall(String method_name) : method_name_(method_name), output_(NullOpt) {} |
198 | /*! \brief Add an integer input */ |
199 | inline void Input(String arg_name, int arg); |
200 | /*! \brief Add an integer input */ |
201 | inline void Input(String arg_name, int64_t arg); |
202 | /*! \brief Add a bool input */ |
203 | inline void Input(String arg_name, bool arg); |
204 | /*! \brief Add a double input */ |
205 | inline void Input(String arg_name, double arg); |
206 | /*! \brief Add an input random variable */ |
207 | inline void Input(String arg_name, String arg); |
208 | /*! \brief Add an input, dispatched to different implementations according to the object's type */ |
209 | inline void Input(String arg_name, ObjectRef arg); |
210 | /*! \brief Add the decision */ |
211 | inline void Decision(ObjectRef decision); |
212 | /*! |
213 | * \brief Add a single output random variable |
214 | * \param unit_array An array containing only one element |
215 | */ |
216 | inline void SingleOutput(Array<String> unit_array); |
217 | /*! \brief Add a list of output random variables */ |
218 | inline void OutputList(Array<String> outputs); |
219 | /*! \returns The schedule API call in python syntax */ |
220 | inline String Str() const; |
221 | |
222 | private: |
223 | /*! \brief Converts a TVM object to python string and print to the output stream */ |
224 | inline void AsPythonString(const ObjectRef& obj, std::ostream& os); |
225 | |
226 | private: |
227 | /*! \brief The name of the API to call */ |
228 | String method_name_; |
229 | /*! \brief The output of the instruction */ |
230 | Optional<String> output_; |
231 | /*! \brief The names of input arguments */ |
232 | std::vector<String> arg_names_; |
233 | /*! \brief The values of input arguments */ |
234 | std::vector<String> args_; |
235 | }; |
236 | |
237 | /********** implementation details **********/ |
238 | |
239 | // forward declaration |
240 | namespace details { |
241 | |
242 | template <typename... Args> |
243 | struct _ArgsPacker; |
244 | |
245 | template <> |
246 | struct _ArgsPacker<> { |
247 | static constexpr bool checked = true; |
248 | }; |
249 | |
250 | template <typename TObjectRef, typename... Args> |
251 | struct _ArgsPacker<TObjectRef, Args...> { |
252 | static constexpr bool checked = |
253 | std::is_base_of<ObjectRef, TObjectRef>::value && _ArgsPacker<Args...>::checked; |
254 | }; |
255 | |
256 | template <typename T> |
257 | struct _MethodType {}; |
258 | |
259 | template <typename TReturn, typename... Args> |
260 | struct _MethodType<TReturn(Args...)> { |
261 | using return_type = TReturn; |
262 | using argument_type = _ArgsPacker<Args...>; |
263 | }; |
264 | |
265 | template <typename T> |
266 | struct _NumArgs {}; |
267 | |
268 | template <typename TReturn, typename... Args> |
269 | struct _NumArgs<TReturn(Args...)> { |
270 | static constexpr size_t value = sizeof...(Args); |
271 | }; |
272 | |
273 | template <typename> |
274 | struct _IsTVMArray : std::false_type {}; |
275 | |
276 | template <typename T> |
277 | struct _IsTVMArray<runtime::Array<T>> : std::true_type {}; |
278 | |
279 | template <typename T> |
280 | struct _IsSingleObject |
281 | : std::integral_constant<bool, std::is_base_of<ObjectRef, T>::value && !_IsTVMArray<T>::value> { |
282 | }; |
283 | |
284 | template <class T> |
285 | using ReturnType = typename _MethodType<std::remove_cv_t<T>>::return_type; |
286 | |
287 | template <class T> |
288 | static constexpr bool ArgumentAreAllObjects = |
289 | _MethodType<std::remove_cv_t<T>>::argument_type::checked; |
290 | |
291 | template <class T> |
292 | static constexpr size_t NumArgs = _NumArgs<std::remove_cv_t<T>>::value; |
293 | |
294 | template <class T> |
295 | static constexpr int IsTVMArray = _IsTVMArray<std::remove_cv_t<T>>::value; |
296 | |
297 | template <class T> |
298 | static constexpr int IsSingleObject = _IsSingleObject<std::remove_cv_t<T>>::value; |
299 | |
300 | }; // namespace details |
301 | |
302 | template <class TTraits> |
303 | Array<ObjectRef> UnpackedInstTraits<TTraits>::ApplyToSchedule(const Schedule& sch, |
304 | const Array<ObjectRef>& inputs, |
305 | const Array<ObjectRef>& attrs, |
306 | const Optional<ObjectRef>& decision) { |
307 | using method_type = decltype(TTraits::UnpackedApplyToSchedule); |
308 | using return_type = details::ReturnType<method_type>; |
309 | static_assert(details::ArgumentAreAllObjects<method_type>, |
310 | "All arguments to `UnpackedApplyToSchedule` must be subclasses of ObjectRef" ); |
311 | constexpr size_t kNumArgs = details::NumArgs<method_type>; |
312 | constexpr size_t kNumInputs = TTraits::kNumInputs; |
313 | constexpr size_t kNumAttrs = TTraits::kNumAttrs; |
314 | constexpr size_t kNumDecisions = TTraits::kNumDecisions; |
315 | static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions, |
316 | "length of argument list mismatch" ); |
317 | TVMValue tvm_values[kNumArgs]; |
318 | int tvm_type_codes[kNumArgs]; |
319 | runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes); |
320 | setter(0, sch); |
321 | TTraits::template _SetInputs<1>(setter, inputs); |
322 | TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs); |
323 | TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision); |
324 | PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { |
325 | using runtime::detail::unpack_call; |
326 | constexpr size_t kNumArgs = details::NumArgs<method_type>; |
327 | ICHECK_EQ(args.size(), kNumArgs); |
328 | unpack_call<return_type, kNumArgs>(nullptr, TTraits::UnpackedApplyToSchedule, args, rv); |
329 | }); |
330 | TVMRetValue rv; |
331 | pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv); |
332 | return TTraits::_ConvertOutputs(rv); |
333 | } |
334 | |
335 | template <class TTraits> |
336 | String UnpackedInstTraits<TTraits>::AsPython(const Array<ObjectRef>& inputs, |
337 | const Array<ObjectRef>& attrs, |
338 | const Optional<ObjectRef>& decision, |
339 | const Array<String>& outputs) { |
340 | using method_type = decltype(TTraits::UnpackedAsPython); |
341 | using return_type = details::ReturnType<method_type>; |
342 | static_assert(details::ArgumentAreAllObjects<method_type>, |
343 | "All arguments to `UnpackedAsPython` must be subclasses of ObjectRef" ); |
344 | constexpr size_t kNumArgs = details::NumArgs<method_type>; |
345 | constexpr size_t kNumInputs = TTraits::kNumInputs; |
346 | constexpr size_t kNumAttrs = TTraits::kNumAttrs; |
347 | constexpr size_t kNumDecisions = TTraits::kNumDecisions; |
348 | static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions, |
349 | "length of argument list mismatch" ); |
350 | TVMValue tvm_values[kNumArgs]; |
351 | int tvm_type_codes[kNumArgs]; |
352 | runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes); |
353 | setter(0, outputs); |
354 | TTraits::template _SetInputs<1>(setter, inputs); |
355 | TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs); |
356 | TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision); |
357 | PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void { |
358 | using runtime::detail::unpack_call; |
359 | constexpr size_t kNumArgs = details::NumArgs<method_type>; |
360 | ICHECK_EQ(args.size(), kNumArgs); |
361 | unpack_call<return_type, kNumArgs>(nullptr, TTraits::UnpackedAsPython, args, rv); |
362 | }); |
363 | TVMRetValue rv; |
364 | pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv); |
365 | String result = rv; |
366 | return result; |
367 | } |
368 | |
369 | template <class TTraits> |
370 | template <size_t index_offset> |
371 | TVM_ALWAYS_INLINE void UnpackedInstTraits<TTraits>::_SetInputs(const runtime::TVMArgsSetter& setter, |
372 | const Array<ObjectRef>& inputs) { |
373 | constexpr size_t kNumInputs = TTraits::kNumInputs; |
374 | ICHECK_EQ(kNumInputs, inputs.size()) |
375 | << "ValueError: Incorrect kNumInputs for instruction: " << TTraits::kName; |
376 | const ObjectRef* ptr = inputs.template as<ArrayNode>()->begin(); |
377 | for (size_t i = 0; i < kNumInputs; ++i) { |
378 | setter(i + index_offset, *(ptr + i)); |
379 | } |
380 | } |
381 | |
382 | template <class TTraits> |
383 | template <size_t index_offset> |
384 | TVM_ALWAYS_INLINE void UnpackedInstTraits<TTraits>::_SetAttrs(const runtime::TVMArgsSetter& setter, |
385 | const Array<ObjectRef>& attrs) { |
386 | constexpr size_t kNumAttrs = TTraits::kNumAttrs; |
387 | ICHECK_EQ(kNumAttrs, attrs.size()) |
388 | << "ValueError: Incorrect kNumAttrs for instruction: " << TTraits::kName; |
389 | const ObjectRef* ptr = attrs.as<ArrayNode>()->begin(); |
390 | for (size_t i = 0; i < kNumAttrs; ++i) { |
391 | setter(i + index_offset, *(ptr + i)); |
392 | } |
393 | } |
394 | |
395 | template <class TTraits> |
396 | template <size_t index_offset> |
397 | TVM_ALWAYS_INLINE void UnpackedInstTraits<TTraits>::_SetDecision( |
398 | const runtime::TVMArgsSetter& setter, const Optional<ObjectRef>& decision) { |
399 | constexpr size_t kNumDecisions = TTraits::kNumDecisions; |
400 | static_assert(kNumDecisions <= 1, "an instruction is supposed to have at most 1 decision" ); |
401 | if (kNumDecisions == 1) { |
402 | setter(index_offset, decision); |
403 | } else { |
404 | ICHECK(!decision.defined()); |
405 | } |
406 | } |
407 | |
408 | template <class TTraits> |
409 | TVM_ALWAYS_INLINE Array<ObjectRef> UnpackedInstTraits<TTraits>::_ConvertOutputs( |
410 | const TVMRetValue& rv) { |
411 | using method_type = decltype(TTraits::UnpackedApplyToSchedule); |
412 | using return_type = details::ReturnType<method_type>; |
413 | constexpr int is_array = details::IsTVMArray<return_type>; |
414 | constexpr int is_single_obj = details::IsSingleObject<return_type>; |
415 | constexpr int is_void = std::is_void<return_type>::value; |
416 | static_assert(is_array || is_single_obj || is_void, "return type not supported" ); |
417 | static_assert(is_array + is_single_obj + is_void == 1, "internal template error" ); |
418 | if (is_void) { |
419 | return {}; |
420 | } else if (is_single_obj) { |
421 | ObjectRef obj = rv; |
422 | return {obj}; |
423 | } else if (is_array) { |
424 | ObjectRef obj = rv; |
425 | const ArrayNode* array = obj.as<ArrayNode>(); |
426 | return GetRef<Array<ObjectRef>>(array); |
427 | } |
428 | } |
429 | |
430 | /********** PythonAPICall **********/ |
431 | |
432 | inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) { |
433 | if (!obj.defined()) { |
434 | os << "None" ; |
435 | } else if (const auto* str = obj.as<runtime::StringObj>()) { |
436 | os << str->data; |
437 | } else if (const auto* int_imm = obj.as<IntImmNode>()) { |
438 | os << int_imm->value; |
439 | } else if (const auto* float_imm = obj.as<FloatImmNode>()) { |
440 | os.precision(17); |
441 | os << float_imm->value; |
442 | } else if (const auto* array = obj.as<ArrayNode>()) { |
443 | os << '['; |
444 | bool is_first = true; |
445 | for (const ObjectRef& e : *array) { |
446 | if (is_first) { |
447 | is_first = false; |
448 | } else { |
449 | os << ", " ; |
450 | } |
451 | AsPythonString(e, os); |
452 | } |
453 | os << ']'; |
454 | } else if (const auto* dict = obj.as<MapNode>()) { |
455 | os << '{'; |
456 | bool is_first = true; |
457 | std::vector<std::pair<std::string, std::string>> dict_items; |
458 | for (auto it = dict->begin(); it != dict->end(); ++it) { |
459 | std::ostringstream ks; |
460 | AsPythonString(it->first, ks); |
461 | std::ostringstream vs; |
462 | AsPythonString(it->second, vs); |
463 | dict_items.emplace_back(ks.str(), vs.str()); |
464 | } |
465 | std::sort(dict_items.begin(), dict_items.end(), |
466 | [](const auto& p1, const auto& p2) { return p1.first < p2.first; }); |
467 | for (const auto& kv : dict_items) { |
468 | if (is_first) { |
469 | is_first = false; |
470 | } else { |
471 | os << ", " ; |
472 | } |
473 | os << '\"' << kv.first << "\": " << kv.second; |
474 | } |
475 | os << '}'; |
476 | } else { |
477 | LOG(FATAL) << "ValueError: Cannot translate type '" << obj->GetTypeKey() |
478 | << "' to python. Its value is: " << obj; |
479 | throw; |
480 | } |
481 | } |
482 | |
483 | void PythonAPICall::Input(String arg_name, int arg) { |
484 | arg_names_.emplace_back(std::move(arg_name)); |
485 | args_.push_back(std::to_string(arg)); |
486 | } |
487 | |
488 | void PythonAPICall::Input(String arg_name, int64_t arg) { |
489 | arg_names_.emplace_back(std::move(arg_name)); |
490 | args_.push_back(std::to_string(arg)); |
491 | } |
492 | |
493 | void PythonAPICall::Input(String arg_name, bool arg) { |
494 | static const char* true_str = "True" ; |
495 | static const char* false_str = "False" ; |
496 | arg_names_.emplace_back(std::move(arg_name)); |
497 | if (arg) { |
498 | args_.push_back(true_str); |
499 | } else { |
500 | args_.push_back(false_str); |
501 | } |
502 | } |
503 | |
504 | void PythonAPICall::Input(String arg_name, double arg) { |
505 | arg_names_.emplace_back(std::move(arg_name)); |
506 | std::ostringstream os; |
507 | os.precision(17); |
508 | os << arg; |
509 | args_.push_back(os.str()); |
510 | } |
511 | |
512 | void PythonAPICall::Input(String arg_name, String arg) { |
513 | arg_names_.emplace_back(std::move(arg_name)); |
514 | args_.emplace_back(std::move(arg)); |
515 | } |
516 | |
517 | void PythonAPICall::Input(String arg_name, ObjectRef arg) { |
518 | arg_names_.emplace_back(std::move(arg_name)); |
519 | std::ostringstream os; |
520 | AsPythonString(arg, os); |
521 | args_.push_back(os.str()); |
522 | } |
523 | |
524 | void PythonAPICall::Decision(ObjectRef decision) { |
525 | if (decision.defined()) { |
526 | this->Input("decision" , decision); |
527 | } |
528 | } |
529 | |
530 | void PythonAPICall::SingleOutput(Array<String> unit_array) { |
531 | ICHECK_EQ(unit_array.size(), 1); |
532 | this->output_ = unit_array[0]; |
533 | } |
534 | |
535 | void PythonAPICall::OutputList(Array<String> outputs) { |
536 | if (outputs.empty()) { |
537 | return; |
538 | } |
539 | if (outputs.size() == 1) { |
540 | this->output_ = outputs[0] + "," ; |
541 | return; |
542 | } |
543 | std::ostringstream os; |
544 | os << outputs[0]; |
545 | for (int i = 1, n = outputs.size(); i < n; ++i) { |
546 | os << ", " << outputs[i]; |
547 | } |
548 | this->output_ = os.str(); |
549 | } |
550 | |
551 | String PythonAPICall::Str() const { |
552 | std::ostringstream os; |
553 | if (output_.defined()) { |
554 | os << output_.value() << " = " ; |
555 | } |
556 | os << "sch." << method_name_ << '('; |
557 | int n = args_.size(); |
558 | for (int i = 0; i < n; ++i) { |
559 | if (i > 0) { |
560 | os << ", " ; |
561 | } |
562 | if (arg_names_[i].empty()) { |
563 | os << args_[i]; |
564 | } else { |
565 | os << arg_names_[i] << '=' << args_[i]; |
566 | } |
567 | } |
568 | os << ')'; |
569 | return os.str(); |
570 | } |
571 | |
572 | } // namespace tir |
573 | } // namespace tvm |
574 | |
575 | #endif // TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_ |
576 | |