1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file lower_device_storage_access.cc
22 * \brief Lower the special device storage access.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/target/target_info.h>
27#include <tvm/tir/buffer.h>
28#include <tvm/tir/builtin.h>
29#include <tvm/tir/stmt_functor.h>
30#include <tvm/tir/transform.h>
31
32#include "../../runtime/thread_storage_scope.h"
33#include "ir_utils.h"
34
35namespace tvm {
36namespace tir {
37
38using runtime::StorageRank;
39using runtime::StorageScope;
40
41class StorageAccessInfoLower : public StmtExprMutator {
42 public:
43 Stmt VisitStmt_(const AllocateNode* op) final {
44 auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
45 if (scope.tag.length() != 0 && scope.tag != ".dyn") {
46 auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
47 ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string();
48 ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end())
49 << "Double allocation of " << scope.to_string();
50 storage_info_[op->buffer_var.get()] = info;
51
52 // Lower allocate to device allocate when needed.
53 Stmt stmt = StmtExprMutator::VisitStmt_(op);
54 op = stmt.as<AllocateNode>();
55 if (info->head_address.defined()) {
56 return LetStmt(op->buffer_var, info->head_address, op->body);
57 } else {
58 return op->body;
59 }
60 } else {
61 return StmtExprMutator::VisitStmt_(op);
62 }
63 }
64
65 PrimExpr VisitExpr_(const CallNode* op) final {
66 if (op->op.same_as(builtin::tvm_access_ptr())) {
67 return MakeAccessPtr(op);
68 } else {
69 return StmtExprMutator::VisitExpr_(op);
70 }
71 }
72
73 private:
74 // tvm_access_ptr
75 PrimExpr MakeAccessPtr(const CallNode* op) {
76 // Specially handle the buffer packed intrinsic
77 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
78 op = expr.as<CallNode>();
79 ICHECK_EQ(op->args.size(), 5U);
80 DataType dtype = op->args[0].dtype();
81 const VarNode* buffer = op->args[1].as<VarNode>();
82 Var buffer_var = Downcast<Var>(op->args[1]);
83 PrimExpr offset = op->args[2];
84 auto it = storage_info_.find(buffer);
85 if (it != storage_info_.end() && it->second.defined()) {
86 return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second);
87 }
88 ICHECK(op->dtype.is_handle());
89 // Change to address_of
90 return AddressOffset(buffer_var, dtype, offset);
91 }
92
93 PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, PrimExpr offset,
94 const MemoryInfo& info) {
95 if (ptr_type.is_handle()) {
96 ICHECK(info->head_address.defined()) << buffer_var << " is not adddressable.";
97 return AddressOffset(buffer_var, dtype, offset);
98 }
99 int dtype_bits = dtype.bits() * dtype.lanes();
100 ICHECK_EQ(info->unit_bits % dtype_bits, 0);
101 return cast(ptr_type, analyzer_.Simplify(
102 offset / make_const(offset.dtype(), info->unit_bits / dtype_bits)));
103 }
104 // The storage scope of each buffer
105 std::unordered_map<const VarNode*, MemoryInfo> storage_info_;
106 // analyzer
107 arith::Analyzer analyzer_;
108};
109
110Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); }
111
112namespace transform {
113
114Pass LowerDeviceStorageAccessInfo() {
115 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
116 auto* n = f.CopyOnWrite();
117 n->body = StorageAccessInfoLower()(std::move(n->body));
118 return f;
119 };
120 return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
121}
122
// Register the pass constructor as a global function under the name
// "tir.transform.LowerDeviceStorageAccessInfo".
// NOTE(review): presumably this is how language frontends look the pass up by name — confirm.
TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo")
    .set_body_typed(LowerDeviceStorageAccessInfo);
125
126} // namespace transform
127} // namespace tir
128} // namespace tvm
129