1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file unwrap_vector_expr.cc
 * \brief Utility for unwrapping a vector expression into the scalar expression of a single lane
 */
24
25#include "unwrap_vector_expr.h"
26
27#include <tvm/arith/analyzer.h>
28#include <tvm/tir/analysis.h>
29#include <tvm/tir/builtin.h>
30#include <tvm/tir/expr.h>
31#include <tvm/tir/expr_functor.h>
32#include <tvm/tir/op.h>
33
34#include <unordered_map>
35
36namespace tvm {
37namespace arith {
38
39using namespace tir;
40
41class Scalarizer : public ExprMutator {
42 public:
43 explicit Scalarizer(PrimExpr lane) : lane_(lane) {}
44
45 PrimExpr VisitExpr_(const RampNode* op) final { return op->base + lane_ * op->stride; }
46
47 PrimExpr VisitExpr_(const BroadcastNode* op) final { return op->value; }
48
49 PrimExpr VisitExpr_(const VarNode* op) final {
50 Var var = GetRef<Var>(op);
51
52 auto it = let_var_remap_.find(op);
53 if (it != let_var_remap_.end()) {
54 return it->second;
55 } else {
56 return ExprMutator::VisitExpr_(op);
57 }
58 }
59 PrimExpr VisitExpr_(const LetNode* op) final {
60 if (op->value.dtype().lanes() == 1) {
61 return ExprMutator::VisitExpr_(op);
62 }
63
64 auto it = let_var_remap_.find(op->var.get());
65 ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var;
66
67 Var new_var(op->var->name_hint + "_scalar", op->var.dtype().element_of());
68 let_var_remap_[op->var.get()] = new_var;
69
70 PrimExpr value = this->VisitExpr(op->value);
71 PrimExpr body = this->VisitExpr(op->body);
72
73 let_var_remap_.erase(op->var.get());
74 return Let(op->var, value, body);
75 }
76
77 private:
78 // The lane to extract
79 PrimExpr lane_;
80
81 // Let binding
82 std::unordered_map<const VarNode*, Var> let_var_remap_;
83};
84
85PrimExpr UnwrapVectorExpr(const PrimExpr& vector_expr, const PrimExpr& lane) {
86 return Scalarizer(lane)(vector_expr);
87}
88
89} // namespace arith
90} // namespace tvm
91