1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file lower_device_storage_access.cc |
22 | * \brief Lower the special device storage access. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/target/target_info.h> |
27 | #include <tvm/tir/buffer.h> |
28 | #include <tvm/tir/builtin.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | #include <tvm/tir/transform.h> |
31 | |
32 | #include "../../runtime/thread_storage_scope.h" |
33 | #include "ir_utils.h" |
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
38 | using runtime::StorageRank; |
39 | using runtime::StorageScope; |
40 | |
41 | class StorageAccessInfoLower : public StmtExprMutator { |
42 | public: |
43 | Stmt VisitStmt_(const AllocateNode* op) final { |
44 | auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); |
45 | if (scope.tag.length() != 0 && scope.tag != ".dyn" ) { |
46 | auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); |
47 | ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); |
48 | ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) |
49 | << "Double allocation of " << scope.to_string(); |
50 | storage_info_[op->buffer_var.get()] = info; |
51 | |
52 | // Lower allocate to device allocate when needed. |
53 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
54 | op = stmt.as<AllocateNode>(); |
55 | if (info->head_address.defined()) { |
56 | return LetStmt(op->buffer_var, info->head_address, op->body); |
57 | } else { |
58 | return op->body; |
59 | } |
60 | } else { |
61 | return StmtExprMutator::VisitStmt_(op); |
62 | } |
63 | } |
64 | |
65 | PrimExpr VisitExpr_(const CallNode* op) final { |
66 | if (op->op.same_as(builtin::tvm_access_ptr())) { |
67 | return MakeAccessPtr(op); |
68 | } else { |
69 | return StmtExprMutator::VisitExpr_(op); |
70 | } |
71 | } |
72 | |
73 | private: |
74 | // tvm_access_ptr |
75 | PrimExpr MakeAccessPtr(const CallNode* op) { |
76 | // Specially handle the buffer packed intrinsic |
77 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
78 | op = expr.as<CallNode>(); |
79 | ICHECK_EQ(op->args.size(), 5U); |
80 | DataType dtype = op->args[0].dtype(); |
81 | const VarNode* buffer = op->args[1].as<VarNode>(); |
82 | Var buffer_var = Downcast<Var>(op->args[1]); |
83 | PrimExpr offset = op->args[2]; |
84 | auto it = storage_info_.find(buffer); |
85 | if (it != storage_info_.end() && it->second.defined()) { |
86 | return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second); |
87 | } |
88 | ICHECK(op->dtype.is_handle()); |
89 | // Change to address_of |
90 | return AddressOffset(buffer_var, dtype, offset); |
91 | } |
92 | |
93 | PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, PrimExpr offset, |
94 | const MemoryInfo& info) { |
95 | if (ptr_type.is_handle()) { |
96 | ICHECK(info->head_address.defined()) << buffer_var << " is not adddressable." ; |
97 | return AddressOffset(buffer_var, dtype, offset); |
98 | } |
99 | int dtype_bits = dtype.bits() * dtype.lanes(); |
100 | ICHECK_EQ(info->unit_bits % dtype_bits, 0); |
101 | return cast(ptr_type, analyzer_.Simplify( |
102 | offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); |
103 | } |
104 | // The storage scope of each buffer |
105 | std::unordered_map<const VarNode*, MemoryInfo> storage_info_; |
106 | // analyzer |
107 | arith::Analyzer analyzer_; |
108 | }; |
109 | |
110 | Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); } |
111 | |
112 | namespace transform { |
113 | |
114 | Pass LowerDeviceStorageAccessInfo() { |
115 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
116 | auto* n = f.CopyOnWrite(); |
117 | n->body = StorageAccessInfoLower()(std::move(n->body)); |
118 | return f; |
119 | }; |
120 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo" , {}); |
121 | } |
122 | |
// Register the pass constructor as the global function
// "tir.transform.LowerDeviceStorageAccessInfo".
TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo" )
    .set_body_typed(LowerDeviceStorageAccessInfo);
125 | |
126 | } // namespace transform |
127 | } // namespace tir |
128 | } // namespace tvm |
129 | |