1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file update_pointer_storage_scope.cc
22 * \brief A pass to update storage scopes for buffer variables.
23 */
24#include "update_pointer_storage_scope.h"
25
26#include <tvm/tir/expr.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/stmt_functor.h>
29#include <tvm/tir/transform.h>
30
31#include <unordered_map>
32
33#include "../../runtime/thread_storage_scope.h"
34#include "ir_utils.h"
35
36namespace tvm {
37namespace tir {
38
39Var WithStorageScope(const VarNode* buffer_var, String storage_scope) {
40 auto* ptr_type = buffer_var->type_annotation.as<PointerTypeNode>();
41 ICHECK(ptr_type) << "The provided variable is not of pointer type";
42 return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope),
43 buffer_var->span);
44}
45
46UpdatePointerStorageScope::UpdatePointerStorageScope(
47 const std::unordered_map<const VarNode*, String>& new_storage_scopes) {
48 for (auto& kv : new_storage_scopes) {
49 new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second);
50 }
51}
52
53PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) {
54 auto it = new_var_remap_.find(op);
55 if (it == new_var_remap_.end()) {
56 return GetRef<Var>(op);
57 }
58 return it->second;
59}
60
61Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) {
62 auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
63 return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition),
64 StmtExprMutator::VisitStmt(op->body));
65}
66
67template <typename Node>
68Node UpdatePointerStorageScope::UpdateBufferAccess(Node node) {
69 auto new_buffer = GetUpdatedBuffer(node->buffer);
70 if (!new_buffer.same_as(node->buffer)) {
71 auto writer = node.CopyOnWrite();
72 writer->buffer = new_buffer;
73 }
74 return node;
75}
76
77Buffer UpdatePointerStorageScope::GetUpdatedBuffer(Buffer buf) {
78 // Use the cached buffer, if it exists.
79 auto key = buf.get();
80 auto it = new_buffer_remap_.find(key);
81 if (it != new_buffer_remap_.end()) {
82 return it->second;
83 }
84
85 // Update the buffer's var, if needed.
86 auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(buf->data));
87 if (!remapped.same_as(buf->data)) {
88 auto writer = buf.CopyOnWrite();
89 writer->data = remapped;
90 }
91
92 // Update the cache and return
93 new_buffer_remap_[key] = buf;
94 return buf;
95}
96
97PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) {
98 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
99}
100
101PrimExpr UpdatePointerStorageScope::VisitExpr_(const BufferLoadNode* op) {
102 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
103 return UpdateBufferAccess(node);
104}
105
106Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) {
107 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
108}
109
110Stmt UpdatePointerStorageScope::VisitStmt_(const BufferStoreNode* op) {
111 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
112 return UpdateBufferAccess(node);
113}
114
115} // namespace tir
116} // namespace tvm
117