1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/pattern_match.h |
22 | * |
23 | * \brief Internal tool for expression-template based pattern matching. |
24 | * |
25 | * It helps to simplify pattern matching and rewrites. |
26 | * All the patterns are generated via expression template during compile time, |
27 | * so the resulting code should be as efficient as manually written pattern match code. |
28 | * |
29 | * The code below shows how to use the pattern matcher. |
30 | * |
31 | * \code |
32 | * |
33 | * // max(x + z, y + z) => max(x, y) + z |
34 | * arith::PVar<PrimExpr> x, y, z; |
35 | * |
36 | * // The following code tries to match the declared pattern. |
37 | * // Match will fill the result of match into PVar if successful. |
38 | * // Note that z occurs twice in the pattern, |
39 | * an equality check is performed to ensure each occurrence of z |
40 | * // is equivalent to each other. |
41 | * if (max(x + z, y + z).Match(expr)) { |
42 | * // Eval evaluates a pattern with the current matched value. |
43 | * // The filled value is valid until the next call to Match. |
44 | * return (max(x, y) + z).Eval(); |
45 | * } |
46 | * |
47 | * tvm::tir::Var tx, ty; |
48 | * arith::PVar<IntImm> c; |
49 | * arith::PVar<Var> v; |
50 | * // We can match IntImm and Var, both of which are |
51 | * // special-case containers of PrimExpr |
52 | * ICHECK((v * c).Match(tx * 3)); |
53 | * ICHECK_EQ(c.Eval()->value, 3); |
54 | * // cannot match c to ty |
55 | * ICHECK(!(v * c).Match(tx * ty)); |
56 | * |
57 | * \endcode |
58 | * |
59 | * \note The pattern matcher is not threadsafe, |
60 | * do not use the same PVar in multiple threads. |
61 | * |
62 | * Please be aware that the filled value in a PVar |
63 | * can be overridden in the next call to Match. |
64 | */ |
65 | #ifndef TVM_ARITH_PATTERN_MATCH_H_ |
66 | #define TVM_ARITH_PATTERN_MATCH_H_ |
67 | |
68 | #include <tvm/tir/analysis.h> |
69 | #include <tvm/tir/builtin.h> |
70 | #include <tvm/tir/expr.h> |
71 | |
72 | #include <cmath> |
73 | #include <tuple> |
74 | |
75 | #include "const_fold.h" |
76 | |
77 | namespace tvm { |
78 | namespace arith { |
/*!
 * \brief CRTP base class shared by every pattern node.
 *
 * Each concrete pattern supplies two operations:
 *  - Match: test a value against the pattern, binding PVars on success.
 *  - Eval: rebuild a value from the currently bound PVars.
 *
 * Patterns compose as expression templates through the curiously
 * recurring template pattern.
 *
 * \tparam Derived The concrete pattern type.
 */
template <typename Derived>
class Pattern {
 public:
  /*!
   * \brief Type used when this pattern is stored inside a larger expression.
   *
   * Patterns are nested by value (Derived) unless the derived class
   * overrides this alias; PVar-like holes override it with a const
   * reference so that every occurrence of one hole shares a single binding.
   * The technique is borrowed from Eigen.
   */
  using Nested = Derived;

  /*!
   * \brief Test whether value matches this pattern.
   *
   * All holes are reset via InitMatch_ first, then Match_ is run. Bound PVar
   * contents stay valid until the next call to Match.
   *
   * \param value The value to match.
   * \return true when value matches the pattern.
   */
  template <typename NodeType>
  bool Match(const NodeType& value) const {
    const Derived& self = derived();
    self.InitMatch_();
    return self.Match_(value);
  }

  /*! \return This object viewed as the concrete Derived pattern. */
  const Derived& derived() const {
    const Derived* concrete = static_cast<const Derived*>(this);
    return *concrete;
  }
};
123 | |
/*!
 * \brief Generic equality checker, used when a PVar is bound more than once
 *        within a single match.
 *
 * Delegates to the value type's own operator==.
 *
 * \tparam T The type of the values being compared.
 */
template <typename T>
class PEqualChecker {
 public:
  bool operator()(const T& lhs, const T& rhs) const {
    const bool equal = (lhs == rhs);
    return equal;
  }
};
133 | |
134 | template <> |
135 | class PEqualChecker<PrimExpr> { |
136 | public: |
137 | bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { |
138 | if (lhs.same_as(rhs)) return true; |
139 | return tir::ExprDeepEqual()(lhs, rhs); |
140 | } |
141 | }; |
142 | |
143 | template <> |
144 | class PEqualChecker<IntImm> { |
145 | public: |
146 | bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } |
147 | }; |
148 | |
149 | template <> |
150 | class PEqualChecker<FloatImm> { |
151 | public: |
152 | bool operator()(const FloatImm& lhs, const FloatImm& rhs) const { |
153 | return std::fabs(lhs->value - rhs->value) < 1e-20; |
154 | } |
155 | }; |
156 | |
157 | template <> |
158 | class PEqualChecker<tir::Var> { |
159 | public: |
160 | bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); } |
161 | }; |
162 | |
163 | /*! |
164 | * \brief Pattern variable container. |
165 | * |
166 | * PVar is used as a "hole" in the pattern that can be matched. |
167 | * |
168 | * \tparam T the type of the hole. |
169 | * |
170 | * \note PVar is not thread safe. |
171 | * Do not use the same PVar in multiple threads. |
172 | */ |
173 | template <typename T> |
174 | class PVar : public Pattern<PVar<T>> { |
175 | public: |
176 | // Store PVars by reference in the expression. |
177 | using Nested = const PVar<T>&; |
178 | |
179 | void InitMatch_() const { filled_ = false; } |
180 | |
181 | bool Match_(const T& value) const { |
182 | if (!filled_) { |
183 | value_ = value; |
184 | filled_ = true; |
185 | return true; |
186 | } else { |
187 | return PEqualChecker<T>()(value_, value); |
188 | } |
189 | } |
190 | |
191 | template <typename NodeRefType, |
192 | typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type> |
193 | bool Match_(const NodeRefType& value) const { |
194 | if (const auto* ptr = value.template as<typename T::ContainerType>()) { |
195 | return Match_(GetRef<T>(ptr)); |
196 | } else { |
197 | return false; |
198 | } |
199 | } |
200 | |
201 | T Eval() const { |
202 | ICHECK(filled_); |
203 | return value_; |
204 | } |
205 | |
206 | T EvalOr(const T& default_value) const { return filled_ ? value_ : default_value; } |
207 | |
208 | protected: |
209 | /*! \brief The matched value */ |
210 | mutable T value_; |
211 | /*! \brief whether the variable has been filled */ |
212 | mutable bool filled_{false}; |
213 | }; |
214 | |
215 | /*! |
216 | * \brief Wrapper for pattern variable container with extra match logic. |
217 | * |
218 | * \tparam Derived the type of derived class. |
219 | * \tparam T the type of the hole. |
220 | */ |
221 | template <typename Derived, typename T> |
222 | class PVarWithCheck : public arith::Pattern<PVarWithCheck<Derived, T>> { |
223 | public: |
224 | // Store by reference in the expression. |
225 | using Nested = const PVarWithCheck<Derived, T>&; |
226 | |
227 | void InitMatch_() const { pvar_.InitMatch_(); } |
228 | |
229 | bool Match_(const T& value) const { |
230 | if (!static_cast<const Derived*>(this)->Match_(value)) return false; |
231 | return pvar_.Match_(value); |
232 | } |
233 | |
234 | template <typename NodeRefType, |
235 | typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type> |
236 | bool Match_(const NodeRefType& value) const { |
237 | if (const auto* ptr = value.template as<typename T::ContainerType>()) { |
238 | return Match_(GetRef<T>(ptr)); |
239 | } else { |
240 | return false; |
241 | } |
242 | } |
243 | |
244 | T Eval() const { return pvar_.Eval(); } |
245 | |
246 | protected: |
247 | arith::PVar<T> pvar_; |
248 | }; |
249 | |
250 | /*! |
251 | * \brief Pattern variable container with expr type check. |
252 | * |
253 | * \tparam T the type of the hole. |
254 | * \tparam DType the Pattern type of dtype. |
255 | */ |
256 | template <typename T, typename DType, |
257 | typename = std::enable_if<std::is_base_of<T, PrimExpr>::value>> |
258 | class PVarWithDataType : public PVarWithCheck<PVarWithDataType<T, DType>, T> { |
259 | public: |
260 | explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {} |
261 | |
262 | bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } |
263 | |
264 | protected: |
265 | typename DType::Nested dtype_; |
266 | }; |
267 | |
268 | /*! |
269 | * \brief Pattern variable container for data type with lanes. |
270 | */ |
271 | class PVecDataType : public PVarWithCheck<PVecDataType, DataType> { |
272 | public: |
273 | /*! \brief construct vector dtype placeholder with element type check */ |
274 | explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} |
275 | |
276 | bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); } |
277 | |
278 | protected: |
279 | DataType elem_dtype_; |
280 | }; |
281 | |
282 | /*! |
283 | * \brief Constant Pattern variable container. |
284 | * |
285 | * \tparam T the type of the hole. |
286 | */ |
287 | template <typename T> |
288 | class PConst : public Pattern<PConst<T>> { |
289 | public: |
290 | PConst(T value) // NOLINT(*) |
291 | : value_(value) {} |
292 | |
293 | void InitMatch_() const {} |
294 | |
295 | bool Match_(const T& value) const { return PEqualChecker<T>()(value_, value); } |
296 | |
297 | T Eval() const { return value_; } |
298 | |
299 | private: |
300 | const T value_; |
301 | }; |
302 | |
303 | /*! |
304 | * \brief Pattern binary expression. |
305 | * \tparam OpType The AST noderef type. |
306 | * \tparam TA The pattern type of the first operand. |
307 | * \tparam TB The pattern type of the second operand. |
308 | */ |
309 | template <typename OpType, typename TA, typename TB> |
310 | class PBinaryExpr : public Pattern<PBinaryExpr<OpType, TA, TB>> { |
311 | public: |
312 | PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {} |
313 | |
314 | void InitMatch_() const { |
315 | a_.InitMatch_(); |
316 | b_.InitMatch_(); |
317 | } |
318 | |
319 | bool Match_(const ObjectRef& node) const { |
320 | using NodeType = typename OpType::ContainerType; |
321 | if (const NodeType* ptr = node.as<NodeType>()) { |
322 | if (!a_.Match_(ptr->a)) return false; |
323 | if (!b_.Match_(ptr->b)) return false; |
324 | return true; |
325 | } else { |
326 | return false; |
327 | } |
328 | } |
329 | |
330 | PrimExpr Eval() const { |
331 | PrimExpr lhs = a_.Eval(); |
332 | PrimExpr rhs = b_.Eval(); |
333 | if (auto ret = TryConstFold<OpType>(lhs, rhs)) return ret.value(); |
334 | return OpType(lhs, rhs); |
335 | } |
336 | |
337 | private: |
338 | typename TA::Nested a_; |
339 | typename TB::Nested b_; |
340 | }; |
341 | |
342 | template <typename TA> |
343 | class PConstWithTypeLike : public Pattern<PConstWithTypeLike<TA>> { |
344 | public: |
345 | PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {} |
346 | |
347 | void InitMatch_() const {} |
348 | |
349 | bool Match_(const ObjectRef& node) const { |
350 | if (const tir::IntImmNode* ptr = node.as<tir::IntImmNode>()) { |
351 | return ptr->value == value_; |
352 | } else { |
353 | return false; |
354 | } |
355 | } |
356 | |
357 | PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); } |
358 | |
359 | private: |
360 | typename TA::Nested ref_; |
361 | int64_t value_; |
362 | }; |
363 | |
/*!
 * \brief Define the binary pattern operator FuncName, building
 *        PBinaryExpr<NodeName, ...>.
 *
 * Three overloads are generated:
 *  - FuncName(Pattern, Pattern);
 *  - FuncName(Pattern, int64_t): the integer is wrapped in a
 *    PConstWithTypeLike whose Eval dtype follows the pattern operand;
 *  - FuncName(int64_t, Pattern): the same, with the constant on the left.
 *
 * CheckStep is a statement placed at the start of every overload. It may
 * refer to the pattern operand `a` (e.g. DivAmbiguityError(a)).
 */
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep)                                 \
  template <typename TA, typename TB>                                                           \
  inline PBinaryExpr<NodeName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
    CheckStep;                                                                                  \
    return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived());                             \
  }                                                                                             \
  template <typename TA>                                                                        \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA>> FuncName(const Pattern<TA>& a,       \
                                                                    int64_t b) {                \
    CheckStep;                                                                                  \
    return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b));                                 \
  }                                                                                             \
  template <typename TA>                                                                        \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> FuncName(int64_t b,                  \
                                                                    const Pattern<TA>& a) {     \
    CheckStep;                                                                                  \
    return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a);                                 \
  }

/*! \brief TVM_PATTERN_BINARY_OP_EX with an empty check step. */
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
384 | |
// operator/ and operator% are generated with DivAmbiguityError(a) as the check
// step, because / and % do not say which division semantics (truncated or
// floor) is intended. NOTE(review): DivAmbiguityError is declared outside this
// file; presumably it reports an error when these overloads are used - confirm.
// The named functions below (truncdiv, floordiv, ...) pick the semantics
// explicitly.
TVM_PATTERN_BINARY_OP_EX(operator/, tir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, tir::Mod, DivAmbiguityError(a));

// arithmetic expressions
// Note: div and truncdiv both build tir::Div; truncmod builds tir::Mod.
TVM_PATTERN_BINARY_OP(operator+, tir::Add);
TVM_PATTERN_BINARY_OP(operator-, tir::Sub);
TVM_PATTERN_BINARY_OP(operator*, tir::Mul);
TVM_PATTERN_BINARY_OP(min, tir::Min);
TVM_PATTERN_BINARY_OP(max, tir::Max);
TVM_PATTERN_BINARY_OP(div, tir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, tir::Div);
TVM_PATTERN_BINARY_OP(truncmod, tir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, tir::FloorMod);

// logical expressions: comparisons and boolean connectives
TVM_PATTERN_BINARY_OP(operator>, tir::GT);
TVM_PATTERN_BINARY_OP(operator>=, tir::GE);
TVM_PATTERN_BINARY_OP(operator<, tir::LT);
TVM_PATTERN_BINARY_OP(operator<=, tir::LE);
TVM_PATTERN_BINARY_OP(operator==, tir::EQ);
TVM_PATTERN_BINARY_OP(operator!=, tir::NE);
TVM_PATTERN_BINARY_OP(operator&&, tir::And);
TVM_PATTERN_BINARY_OP(operator||, tir::Or);
410 | |
411 | /*! |
412 | * \brief Pattern not expression. |
413 | * \tparam TA The pattern type of the true operand. |
414 | */ |
415 | template <typename TA> |
416 | class PNotExpr : public Pattern<PNotExpr<TA>> { |
417 | public: |
418 | explicit PNotExpr(const TA& value) : value_(value) {} |
419 | |
420 | void InitMatch_() const { value_.InitMatch_(); } |
421 | |
422 | bool Match_(const ObjectRef& node) const { |
423 | if (const tir::NotNode* ptr = node.as<tir::NotNode>()) { |
424 | if (!value_.Match_(ptr->a)) return false; |
425 | return true; |
426 | } else { |
427 | return false; |
428 | } |
429 | } |
430 | |
431 | PrimExpr Eval() const { return tir::Not(value_.Eval()); } |
432 | |
433 | private: |
434 | typename TA::Nested value_; |
435 | }; |
436 | |
437 | template <typename TA> |
438 | inline PNotExpr<TA> operator!(const Pattern<TA>& value) { |
439 | return PNotExpr<TA>(value.derived()); |
440 | } |
441 | |
442 | // select |
443 | /*! |
444 | * \brief Pattern select expression. |
445 | * \tparam TCond The pattern type of the condition. |
446 | * \tparam TA The pattern type of the true operand. |
447 | * \tparam TB The pattern type of the false operand. |
448 | */ |
449 | template <typename TCond, typename TA, typename TB> |
450 | class PSelectExpr : public Pattern<PSelectExpr<TCond, TA, TB>> { |
451 | public: |
452 | PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value) |
453 | : condition_(condition), true_value_(true_value), false_value_(false_value) {} |
454 | |
455 | void InitMatch_() const { |
456 | condition_.InitMatch_(); |
457 | true_value_.InitMatch_(); |
458 | false_value_.InitMatch_(); |
459 | } |
460 | |
461 | bool Match_(const ObjectRef& node) const { |
462 | if (const tir::SelectNode* ptr = node.as<tir::SelectNode>()) { |
463 | if (!condition_.Match_(ptr->condition)) return false; |
464 | if (!true_value_.Match_(ptr->true_value)) return false; |
465 | if (!false_value_.Match_(ptr->false_value)) return false; |
466 | return true; |
467 | } else { |
468 | return false; |
469 | } |
470 | } |
471 | |
472 | PrimExpr Eval() const { |
473 | return tir::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); |
474 | } |
475 | |
476 | private: |
477 | typename TCond::Nested condition_; |
478 | typename TA::Nested true_value_; |
479 | typename TB::Nested false_value_; |
480 | }; |
481 | |
482 | /*! |
483 | * \brief Construct a select pattern. |
484 | * |
485 | * \param condition The condition expression. |
486 | * \param true_value The value when condition is true. |
487 | * \param true_value The value when condition is false. |
488 | * |
489 | * \return The result pattern. |
490 | * |
491 | * \tparam TCond The pattern type of the condition. |
492 | * \tparam TA The pattern type of the true operand. |
493 | * \tparam TB The pattern type of the false operand. |
494 | */ |
495 | template <typename TCond, typename TA, typename TB> |
496 | inline PSelectExpr<TCond, TA, TB> select(const Pattern<TCond>& condition, |
497 | const Pattern<TA>& true_value, |
498 | const Pattern<TB>& false_value) { |
499 | return PSelectExpr<TCond, TA, TB>(condition.derived(), true_value.derived(), |
500 | false_value.derived()); |
501 | } |
502 | |
503 | /*! |
504 | * \brief Pattern cast expression. |
505 | * \tparam DType The Pattern type of dtype. |
506 | * \tparam TA The pattern type of the first operand. |
507 | */ |
508 | template <typename DType, typename TA> |
509 | class PCastExpr : public Pattern<PCastExpr<DType, TA>> { |
510 | public: |
511 | PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {} |
512 | |
513 | void InitMatch_() const { |
514 | dtype_.InitMatch_(); |
515 | value_.InitMatch_(); |
516 | } |
517 | |
518 | bool Match_(const ObjectRef& node) const { |
519 | if (const tir::CastNode* ptr = node.as<tir::CastNode>()) { |
520 | if (!dtype_.Match_(ptr->dtype)) return false; |
521 | if (!value_.Match_(ptr->value)) return false; |
522 | return true; |
523 | } else { |
524 | return false; |
525 | } |
526 | } |
527 | |
528 | PrimExpr Eval() const { return tir::Cast(dtype_.Eval(), value_.Eval()); } |
529 | |
530 | private: |
531 | typename DType::Nested dtype_; |
532 | typename TA::Nested value_; |
533 | }; |
534 | |
535 | /*! |
536 | * \brief Construct a cast pattern. |
537 | * |
538 | * \param dtype The target data type, can be PVar<DataType> or PConst<DataType>. |
539 | * \param value The input type. |
540 | * |
541 | * \return The result pattern. |
542 | * |
543 | * \tparam DType The pattern type of type. |
544 | * \tparam TA The pattern type of value. |
545 | */ |
546 | template <typename DType, typename TA> |
547 | inline PCastExpr<DType, TA> cast(const Pattern<DType>& dtype, const Pattern<TA>& value) { |
548 | return PCastExpr<DType, TA>(dtype.derived(), value.derived()); |
549 | } |
550 | |
551 | /*! |
552 | * \brief Pattern ramp expression. |
553 | * \tparam TBase The pattern type of the base. |
554 | * \tparam TStride The pattern type of the stride. |
555 | * \tparam TLanes The pattern type of the lanes. |
556 | */ |
557 | template <typename TBase, typename TStride, typename TLanes> |
558 | class PRampExpr : public Pattern<PRampExpr<TBase, TStride, TLanes>> { |
559 | public: |
560 | PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes) |
561 | : base_(base), stride_(stride), lanes_(lanes) {} |
562 | |
563 | void InitMatch_() const { |
564 | base_.InitMatch_(); |
565 | stride_.InitMatch_(); |
566 | lanes_.InitMatch_(); |
567 | } |
568 | |
569 | bool Match_(const ObjectRef& node) const { |
570 | if (const tir::RampNode* ptr = node.as<tir::RampNode>()) { |
571 | if (!base_.Match_(ptr->base)) return false; |
572 | if (!stride_.Match_(ptr->stride)) return false; |
573 | if (!lanes_.Match_(ptr->lanes)) return false; |
574 | return true; |
575 | } else { |
576 | return false; |
577 | } |
578 | } |
579 | |
580 | PrimExpr Eval() const { return tir::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); } |
581 | |
582 | private: |
583 | typename TBase::Nested base_; |
584 | typename TStride::Nested stride_; |
585 | typename TLanes::Nested lanes_; |
586 | }; |
587 | |
588 | /*! |
589 | * \brief Construct a ramp pattern. |
590 | * |
591 | * \param base The base pattern. |
592 | * \param stride The stride pattern. |
593 | * \param lanes The lanes pattern. |
594 | * |
595 | * \return The result pattern. |
596 | * |
597 | * \tparam TBase The pattern type of the base. |
598 | * \tparam TStride The pattern type of the stride. |
599 | * \tparam TLanes The pattern type of the lanes. |
600 | */ |
601 | template <typename TBase, typename TStride, typename TLanes> |
602 | inline PRampExpr<TBase, TStride, TLanes> ramp(const Pattern<TBase>& base, |
603 | const Pattern<TStride>& stride, |
604 | const Pattern<TLanes>& lanes) { |
605 | return PRampExpr<TBase, TStride, TLanes>(base.derived(), stride.derived(), lanes.derived()); |
606 | } |
607 | |
608 | template <typename TBase> |
609 | inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>> ramp(const Pattern<TBase>& base, |
610 | int stride, int lanes) { |
611 | return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>( |
612 | base.derived(), PConstWithTypeLike<TBase>(base.derived(), stride), PConst<int>(lanes)); |
613 | } |
614 | |
615 | /*! |
616 | * \brief Pattern broadcast expression. |
617 | * \tparam TA The pattern type of the value. |
618 | * \tparam TLanes The pattern type of the lanes. |
619 | */ |
620 | template <typename TA, typename TLanes> |
621 | class PBroadcastExpr : public Pattern<PBroadcastExpr<TA, TLanes>> { |
622 | public: |
623 | PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {} |
624 | |
625 | void InitMatch_() const { |
626 | value_.InitMatch_(); |
627 | lanes_.InitMatch_(); |
628 | } |
629 | |
630 | bool Match_(const ObjectRef& node) const { |
631 | if (const tir::BroadcastNode* ptr = node.as<tir::BroadcastNode>()) { |
632 | if (!value_.Match_(ptr->value)) return false; |
633 | if (!lanes_.Match_(ptr->lanes)) return false; |
634 | return true; |
635 | } else { |
636 | return false; |
637 | } |
638 | } |
639 | |
640 | PrimExpr Eval() const { return tir::Broadcast(value_.Eval(), lanes_.Eval()); } |
641 | |
642 | private: |
643 | typename TA::Nested value_; |
644 | typename TLanes::Nested lanes_; |
645 | }; |
646 | |
647 | /*! |
648 | * \brief Construct a broadcast pattern. |
649 | * |
650 | * \param value The value pattern. |
651 | * \param lanes The lanes pattern. |
652 | * |
653 | * \return The result pattern. |
654 | * |
655 | * \tparam TA The pattern type of the value. |
656 | * \tparam TLanes The pattern type of the lanes. |
657 | */ |
658 | template <typename TA, typename TLanes> |
659 | inline PBroadcastExpr<TA, TLanes> broadcast(const Pattern<TA>& value, |
660 | const Pattern<TLanes>& lanes) { |
661 | return PBroadcastExpr<TA, TLanes>(value.derived(), lanes.derived()); |
662 | } |
663 | |
664 | // internal namespace |
665 | namespace detail { |
666 | // implementation details for CallExpr |
667 | template <bool stop, std::size_t I, typename F> |
668 | struct tuple_for_each_dispatcher { |
669 | template <typename TTuple> |
670 | static void run(F& f, const TTuple& tuple) { // NOLINT(*) |
671 | f(I, std::get<I>(tuple)); |
672 | tuple_for_each_dispatcher<(I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>::run(f, tuple); |
673 | } |
674 | }; |
675 | |
676 | template <std::size_t I, typename F> |
677 | struct tuple_for_each_dispatcher<true, I, F> { |
678 | template <typename TTuple> |
679 | static void run(F& f, const TTuple& tuple) {} // NOLINT(*) |
680 | }; |
681 | |
682 | template <typename F, typename TTuple> |
683 | inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*) |
684 | tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>::run(f, tuple); |
685 | } |
686 | |
687 | struct PCallExprInitMatchFunctor { |
688 | template <typename T> |
689 | void operator()(size_t i, const T& pattern) const { |
690 | pattern.InitMatch_(); |
691 | } |
692 | }; |
693 | |
694 | struct PCallExprMatchFunctor { |
695 | const tir::CallNode* call_; |
696 | bool matched_{true}; |
697 | |
698 | explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {} |
699 | |
700 | template <typename T> |
701 | void operator()(size_t i, const T& pattern) { |
702 | matched_ = matched_ && pattern.Match_(call_->args[i]); |
703 | } |
704 | }; |
705 | |
706 | struct PCallExprEvalArgsFunctor { |
707 | Array<PrimExpr> args_; |
708 | |
709 | template <typename T> |
710 | void operator()(size_t i, const T& pattern) { |
711 | args_.push_back(pattern.Eval()); |
712 | } |
713 | }; |
714 | } // namespace detail |
715 | |
716 | /*! |
717 | * \brief Pattern CallExpr expression. |
718 | * \tparam Op The operator functor class. |
719 | * \tparam TArgs The arguments. |
720 | * \note Op functor contains the name of the function and |
721 | * the implementation of Eval. |
722 | */ |
723 | template <typename Op, typename... TArgs> |
724 | class PCallExpr : public Pattern<PCallExpr<Op, TArgs...>> { |
725 | public: |
726 | explicit PCallExpr(const TArgs&... args) : args_(args...) {} |
727 | |
728 | void InitMatch_() const { |
729 | detail::PCallExprInitMatchFunctor finit; |
730 | detail::tuple_for_each(finit, args_); |
731 | } |
732 | |
733 | bool Match_(const ObjectRef& node) const { |
734 | if (const tir::CallNode* ptr = node.as<tir::CallNode>()) { |
735 | if (ptr->args.size() != sizeof...(TArgs)) return false; |
736 | if (!ptr->op.same_as(Op::GetOp())) return false; |
737 | detail::PCallExprMatchFunctor fmatch(ptr); |
738 | detail::tuple_for_each(fmatch, args_); |
739 | return fmatch.matched_; |
740 | } else { |
741 | return false; |
742 | } |
743 | } |
744 | |
745 | PrimExpr Eval() const { |
746 | detail::PCallExprEvalArgsFunctor feval_args; |
747 | detail::tuple_for_each(feval_args, args_); |
748 | return Op::Eval(feval_args.args_); |
749 | } |
750 | |
751 | private: |
752 | std::tuple<typename TArgs::Nested...> args_; |
753 | }; |
754 | |
// arithmetic intrinsics
/*!
 * \brief Define OpName, an op descriptor for tir::builtin::IntrinOpName, and
 *        a binary pattern function FuncName building PCallExpr<OpName, TA, TB>.
 *
 * OpName::Eval builds the call with the dtype of the first argument;
 * OpName::GetOp returns the builtin op that PCallExpr::Match_ compares against.
 */
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName)                          \
  struct OpName {                                                                          \
    static PrimExpr Eval(Array<PrimExpr> args) {                                           \
      return tir::Call(args[0].dtype(), GetOp(), args);                                    \
    }                                                                                      \
    static const Op& GetOp() { return tir::builtin::IntrinOpName(); }                      \
  };                                                                                       \
  template <typename TA, typename TB>                                                      \
  inline PCallExpr<OpName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
    return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());                            \
  }

// Shift and bitwise operators map to the corresponding tir::builtin intrinsics.
TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left);
TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right);
TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and);
TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or);
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor);
773 | |
// unary intrinsics
/*!
 * \brief Define OpName, an op descriptor for tir::builtin::IntrinOpName, and
 *        a unary pattern function FuncName building PCallExpr<OpName, TA>.
 *
 * OpName::Eval builds the call with the dtype of its (single) argument.
 */
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName)      \
  struct OpName {                                                     \
    static PrimExpr Eval(Array<PrimExpr> args) {                      \
      return tir::Call(args[0].dtype(), GetOp(), args);               \
    }                                                                 \
    static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
  };                                                                  \
  template <typename TA>                                              \
  inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) {       \
    return PCallExpr<OpName, TA>(a.derived());                        \
  }

// Bitwise not maps to tir::builtin::bitwise_not.
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not);
788 | |
789 | // if_then_else |
790 | struct PIfThenElseOp { |
791 | static PrimExpr Eval(Array<PrimExpr> args) { return tir::Call(args[1].dtype(), GetOp(), args); } |
792 | static const Op& GetOp() { return tir::builtin::if_then_else(); } |
793 | }; |
794 | |
795 | /*! |
796 | * \brief Construct a if_then_else pattern. |
797 | * |
798 | * \param cond The condition expression. |
799 | * \param true_value The value when condition is true. |
800 | * \param true_value The value when condition is false. |
801 | * |
802 | * \return The result pattern. |
803 | * |
804 | * \tparam TCond The pattern type of the condition. |
805 | * \tparam TA The pattern type of the true operand. |
806 | * \tparam TB The pattern type of the false operand. |
807 | */ |
808 | template <typename TCond, typename TA, typename TB> |
809 | inline PCallExpr<PIfThenElseOp, TCond, TA, TB> if_then_else(const Pattern<TCond>& cond, |
810 | const Pattern<TA>& true_value, |
811 | const Pattern<TB>& false_value) { |
812 | return PCallExpr<PIfThenElseOp, TCond, TA, TB>(cond.derived(), true_value.derived(), |
813 | false_value.derived()); |
814 | } |
815 | |
816 | } // namespace arith |
817 | } // namespace tvm |
818 | #endif // TVM_ARITH_PATTERN_MATCH_H_ |
819 | |