1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * Combine calls into context related function into one.
22 *
23 * \file combine_context_call.cc
24 */
25#include <tvm/node/structural_equal.h>
26#include <tvm/node/structural_hash.h>
27#include <tvm/runtime/registry.h>
28#include <tvm/tir/builtin.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/stmt.h>
31#include <tvm/tir/stmt_functor.h>
32#include <tvm/tir/transform.h>
33
34#include <unordered_map>
35
36namespace tvm {
37namespace tir {
38
39// Calculate the statistics of packed function.
40// These information are needed during codegen.
41class ContextCallCombiner final : public StmtExprMutator {
42 public:
43 PrimExpr VisitExpr_(const CallNode* op) final {
44 if (op->op.same_as(builtin::tvm_thread_context())) {
45 ICHECK_EQ(op->args.size(), 1U);
46 PrimExpr ctx = op->args[0];
47 auto it = ctx_map_.find(ctx);
48 if (it != ctx_map_.end()) {
49 return it->second;
50 } else {
51 ICHECK(ctx.dtype().is_handle());
52 Var ctx_var("ctx_cache_", ctx.dtype());
53 ctx_map_[ctx] = ctx_var;
54 return std::move(ctx_var);
55 }
56 } else {
57 return StmtExprMutator::VisitExpr_(op);
58 }
59 }
60
61 Stmt VisitStmt_(const AttrStmtNode* op) final {
62 if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) {
63 // Map of comparison expression to variable
64 std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> temp;
65 std::swap(temp, ctx_map_);
66 Stmt stmt = StmtExprMutator::VisitStmt_(op);
67 std::swap(temp, ctx_map_);
68 return BuildContext(temp, stmt);
69 } else {
70 return StmtExprMutator::VisitStmt_(op);
71 }
72 }
73
74 Stmt VisitStmt_(const ForNode* op) final {
75 if (op->kind == ForKind::kParallel) {
76 // Map of comparison expression to variable
77 std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> temp;
78 std::swap(temp, ctx_map_);
79 Stmt stmt = StmtExprMutator::VisitStmt_(op);
80 std::swap(temp, ctx_map_);
81 return BuildContext(temp, stmt);
82 } else {
83 return StmtExprMutator::VisitStmt_(op);
84 }
85 }
86
87 Stmt Combine(Stmt stmt) { return BuildContext(ctx_map_, this->VisitStmt(stmt)); }
88
89 private:
90 static Stmt BuildContext(
91 const std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual>& cmap, Stmt body) {
92 for (const auto& kv : cmap) {
93 body = LetStmt(kv.second, kv.first, body);
94 }
95 return body;
96 }
97 // Map of comparison expression to variable
98 std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> ctx_map_;
99};
100
101namespace transform {
102
103Pass CombineContextCall() {
104 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
105 auto* n = f.CopyOnWrite();
106 n->body = ContextCallCombiner().Combine(std::move(n->body));
107 return f;
108 };
109 return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
110}
111
112TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall);
113
114} // namespace transform
115} // namespace tir
116} // namespace tvm
117