/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/memory_alloc.cc
 * \brief A pass for manifesting explicit memory allocations.
 */

#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>
#include <tvm/target/target.h>

#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_set>
#include <vector>

#include "../backend/te_compiler.h"
#include "../backend/te_compiler_cache.h"
#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../op/memory/memory.h"
#include "../op/vm/vm.h"
#include "./device_aware_visitors.h"
#include "./let_list.h"
#include "./pass_utils.h"
#include "./pattern_utils.h"

using namespace tvm::runtime;

namespace tvm {
namespace relay {

class DialectRewriter : public transform::DeviceAwareExprMutator {
 public:
  DialectRewriter(IRModule mod, VirtualDevice host_virtual_device)
      : transform::DeviceAwareExprMutator(mod),
        mod_(std::move(mod)),
        host_virtual_device_(std::move(host_virtual_device)) {}

  Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }

 private:
  using ExprMutator::VisitExpr_;

  Expr VisitExpr_(const TupleNode* tuple_node) final {
    LetList& scope = scopes_.back();
    Array<Expr> new_fields;
    new_fields.reserve(tuple_node->fields.size());

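    // Constant tuple fields are let-bound and pinned to the field's device, unless they
    // are simple scalars (i32/i64/f32/f64/bool), which can remain inline.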
    for (auto field : tuple_node->fields) {
      auto new_field = Mutate(field);
      if (const auto* op = new_field.as<ConstantNode>()) {
        DataType dtype(op->data->dtype);
        bool is_simple_const = (dtype == DataType::Int(32) || dtype == DataType::Int(64) ||
                                dtype == DataType::Float(32) || dtype == DataType::Float(64) ||
                                dtype == DataType::Bool());
        if (!op->is_scalar() || !is_simple_const) {
          VirtualDevice virtual_device = GetVirtualDevice(field);
          ICHECK(!virtual_device->IsFullyUnconstrained());
          Var const_var("const", Type(nullptr));
          new_field = scope.Push(const_var, MaybeOnDeviceFixed(new_field, virtual_device));
        }
      }
      new_fields.push_back(new_field);
    }
    return WithFields(GetRef<Tuple>(tuple_node), new_fields);
  }

  void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }

  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
    Expr new_value = Mutate(value);
    VirtualDevice virtual_device = GetVirtualDevice(value);
    ICHECK(!virtual_device->IsFullyUnconstrained());
    scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, virtual_device));
    // Since we always need a let block in which to bind sub-expressions, the rewritten
    // bindings are tracked in the current scope. But return the rewritten binding anyway.
    return {var, new_value};
  }

  Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node) final {
    // The current scope has captured all the rewritten let-bindings, as well as any additional
    // bindings we needed to add. All we need is the rewritten body.
    Expr new_body = post_let_node->body;
    while (const auto* inner_let_node = new_body.as<LetNode>()) {
      new_body = inner_let_node->body;
    }
    auto ret = scopes_.back().Get(new_body);
    scopes_.pop_back();
    return ret;
  }

  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

    if (device_copy_props.body.defined()) {
      // Special case: device_copy calls remain in their original (and functional) form.
      // TODO(mbs): device_copy cleanup.
      return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
    }

    if (!call_lowered_props.lowered_func.defined()) {
      // This is a call to a user-defined Relay function, which will be handled directly by
      // the VM and does not need conversion to DPS.
      return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
    }

    Call call = GetRef<Call>(call_node);
    VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call);

    VirtualDevice virtual_device = GetVirtualDevice(call);
    ICHECK(!virtual_device->IsFullyUnconstrained());
    ICHECK(!scopes_.empty())
        << "Calls outside of a let block are not supported. Did you forget to transform "
        << "with ToANormalForm, or to set opt_level >= 1 in the pass context?";
    LetList& scope = scopes_.back();

    std::vector<Expr> new_args;
    for (const auto& arg : call_lowered_props.arguments) {
      new_args.push_back(Mutate(arg));
    }
    Tuple ins(new_args);
    Type ret_type = call_node->checked_type_;
    std::vector<TensorType> out_types = FlattenTupleType(ret_type);

    // Handle reshape.
    // Original:
    //   reshape(body, <ReshapeAttrs>)
    //   dyn.reshape(body, shape, <ReshapeAttrs>)
    // After FuseOps:
    //   let %f = fn(x, primitive=1, relay.reshape_only=1) { reshape(x, <ReshapeAttrs>) }
    //   %f(body)
    // After LowerTEPass:
    //   call_lowered(@xxx_reshape, (body),
    //                <LoweredCallAttrs with relay_attrs |-> dict[relay.reshape_only] = 1>)
    //   -OR-
    //   call_lowered(@xxx_dyn_reshape, (body, shape), <LoweredCallAttrs with same>)
    //   where @xxx_reshape is bound as a PrimFunc.
    //   (The name is irrelevant; only the relay.reshape_only attribute matters.)
    // After this pass:
    //   vm.reshape_tensor(body, shape, <TIRCallAttrs>)
    if (IsReshapeOnly(call_lowered_props)) {
      return EmitReshapeTensor(&scope, ins, call_lowered_props.attrs, ret_type);
    }

    // At this point we could be calling a PrimFunc or an 'external', already-compiled
    // primitive; the calling conventions are identical.

    // Handle 'dynamic' calls, i.e. to PrimFuncs whose result shape must first be computed
    // by a companion shape function.
    if (IsDynamic(ret_type)) {
      return DynamicInvoke(&scope, call_lowered_props.lowered_func, ins, call_lowered_props.attrs,
                           out_types, ret_type, virtual_device);
    }

    // Handle ordinary primitive calls.
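    // A rough sketch of the rewrite (names are illustrative): a call f(x) whose result
    // type is Tensor[(2, 2), float32] becomes, roughly,
    //   let %storage_0 = memory.alloc_storage(16 /*bytes*/, 64 /*alignment*/, ...);
    //   let %out_0 = memory.alloc_tensor(%storage_0, 0 /*offset*/, [2, 2], ...);
    //   let %_ = vm.invoke_tvm_op(f, (x), (%out_0));
    //   %out_0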
    Array<Expr> outputs;
    for (size_t i = 0; i < out_types.size(); ++i) {
      outputs.push_back(
          MakeStaticAllocation(&scope, out_types[i], virtual_device, std::to_string(i)));
    }
    Tuple outs(outputs);
    Expr invoke =
        InvokeTVMOp(call_lowered_props.lowered_func, ins, outs,
                    Downcast<DictAttrs>(call_lowered_props.attrs.metadata.at("relay_attrs")));
    scope.Push(MaybeOnDeviceFixed(invoke, virtual_device));
    return ToTupleType(ret_type, std::vector<Expr>(outputs.begin(), outputs.end()));
  }

  /*!
   * \brief Returns the Relay Constant representing the 1d tensor with \p value.
   *
   * CAUTION: Make sure the constant ends up on the correct device.
   */
  inline Constant MakeConstant(const std::vector<int64_t>& value) {
    return MakeConstantTensor(DataType::Int(64), {static_cast<int64_t>(value.size())}, value);
  }

  /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */
  inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype,
                          Array<IndexExpr> assert_shape) {
    Expr offset =
        MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_virtual_device_);
    return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype,
                                   assert_shape);
  }

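  /*!
   * \brief Returns the byte alignment to use for buffers of \p dtype as an Int(64)
   * constant. E.g. a float32 element occupies 4 bytes, which is clamped up to the
   * 64-byte minimum.
   */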
  Expr ComputeAlignment(const DataType& dtype) const {
    int64_t align = dtype.bits() / 8 * dtype.lanes();
    if (align < 64) {
      align = 64;
    }
    return MakeConstantScalar(DataType::Int(64), align);
  }

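  /*!
   * \brief Returns a Relay expression computing the number of bytes needed for a tensor
   * with (possibly dynamic) \p shape and the element type of \p type:
   *   prod(shape) * ((bits * lanes + 7) / 8)
   * E.g. a [2, 3] float32 tensor needs 6 * ((32 * 1 + 7) / 8) = 24 bytes.
   */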
  Expr ComputeStorageInRelay(const Expr& shape, const TensorType& type) const {
    auto dtype = DataType(type->dtype);
    Expr els = Prod(shape, Array<Integer>(nullptr), false, false);
    Expr num = MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
    Expr add = Add(num, MakeConstantScalar(DataType::Int(64), 7));
    Expr div = MakeConstantScalar(DataType::Int(64), 8);
    return Multiply(els, Divide(add, div));
  }

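  /*!
   * \brief Returns the size in bytes of a statically-shaped tensor of \p type as an
   * Int(64) constant, using the same rounding rule as ComputeStorageInRelay.
   */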
  Expr ComputeStorage(const TensorType& type) {
    int64_t size = 1;
    for (auto it : type->shape) {
      auto val = it.as<IntImmNode>();
      CHECK(val);
      size *= val->value;
    }
    size *= (type->dtype.bits() * type->dtype.lanes() + 7) / 8;
    return MakeConstantScalar(DataType::Int(64), size);
  }

  // Allocate a tensor with a statically known shape.
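  // Sketch of the two bindings this pushes into the scope for a [2, 2] float32 tensor
  // (names illustrative):
  //   let %storage_<name> = memory.alloc_storage(16 /*size*/, 64 /*alignment*/, ...);
  //   let %tensor_<name> = memory.alloc_tensor(%storage_<name>, 0 /*offset*/, [2, 2], ...);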
  Var MakeStaticAllocation(LetList* scope, const TensorType& type,
                           const VirtualDevice& virtual_device, String name_hint) {
    std::vector<int64_t> int_shape;
    for (auto it : type->shape) {
      const auto* imm = it.as<IntImmNode>();
      CHECK(imm) << "expect static int shape";
      int_shape.push_back(imm->value);
    }
    Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_virtual_device_);
    Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_virtual_device_);
    // Alignment is directly captured in the instruction rather than calculated, so we
    // don't want to wrap it with an "on_device".
    Expr alignment = ComputeAlignment(type->dtype);
    // Run type inference later to get the correct type.
    Var var("storage_" + name_hint, Type(nullptr));
    Expr value = AllocStorage(size, alignment, virtual_device, type->dtype);
    auto sto = scope->Push(var, MaybeOnDeviceFixed(value, virtual_device));

    // TODO(@jroesch): There is a bug with typing based on the constant shape.
    auto tensor = AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape);
    Var tensor_var("tensor_" + name_hint, Type(nullptr));
    return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, virtual_device));
  }

  /*!
   * \brief Appends to \p scope the computation necessary to call the shape function given
   * in \p attrs, and binds the resulting shapes into \p scope. The result shapes are for a
   * call to a primitive with \p ins arguments. Some combination of the data and/or shapes
   * of \p ins will be needed by the shape function.
   */
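  // A rough sketch of what this appends for a single input whose shape (but not data) is
  // needed (names are illustrative):
  //   let %in_shape_0 = vm.shape_of(%x);                 // on the host
  //   let %out_shape_0 = <static allocation on the host>;
  //   let %shape_func = vm.invoke_tvm_op(@shape_fn, (%in_shape_0), (%out_shape_0));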
  Array<Expr> EmitShapeFunc(LetList* scope, const Tuple& ins, const CallLoweredAttrs& attrs) {
    ICHECK(attrs.metadata.count("prim_shape_fn_states"));
    Array<Integer> input_states =
        Downcast<Array<Integer>>(attrs.metadata.at("prim_shape_fn_states"));
    ICHECK(attrs.metadata.count("prim_shape_fn_var"));
    auto prim_fn_var = Downcast<GlobalVar>(attrs.metadata.at("prim_shape_fn_var"));

    const auto* func_type_node = prim_fn_var->checked_type().as<FuncTypeNode>();
    ICHECK(func_type_node);

    // Establish the arguments to the shape function.
    Array<Expr> shape_func_ins;
    int input_pos = 0;
    ICHECK_EQ(ins->fields.size(), input_states.size());
    for (size_t i = 0; i < ins->fields.size(); ++i) {
      const Expr& arg = ins->fields[i];
      Type ty;
      if (const auto* vn = arg.as<VarNode>()) {
        ty = vn->type_annotation;
      } else {
        ty = arg->checked_type();
      }
      int64_t state = input_states[i]->value;
      // The shape function needs only the shape of this input.
      if (state == tec::kNeedInputShape) {
        std::vector<Expr> exprs = FromTupleType(ty, arg);
        for (size_t j = 0; j < exprs.size(); ++j) {
          Expr sh_of = Mutate(ShapeOf(exprs[j]));
          Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr));
          shape_func_ins.push_back(
              scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_virtual_device_)));
          input_pos++;
        }
      } else if (state == tec::kNeedInputData) {
        auto new_arg = Mutate(arg);  // already accounts for device
        VirtualDevice arg_virtual_device = GetVirtualDevice(arg);
        ICHECK(!arg_virtual_device->IsFullyUnconstrained());
        // The dynamic shape function expects its data on the host/CPU, so insert a
        // device_copy if necessary. (We'll need to fuse & lower these copies in the same way
        // we fuse & lower other operators we insert for, e.g., dynamic tensor size
        // calculation.)
        new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_virtual_device),
                                  arg_virtual_device, host_virtual_device_);
        Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
        shape_func_ins.push_back(
            scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_virtual_device_)));
        input_pos++;
      } else {
        // TODO(@jroesch): handle kNeedBoth
        LOG(FATAL) << "unsupported shape function input state";
      }
    }
    ICHECK_EQ(shape_func_ins.size(), func_type_node->arg_types.size());

    // Establish the result shapes.
    const auto* res_tuple_node = func_type_node->ret_type.as<TupleTypeNode>();
    ICHECK(res_tuple_node);

    Array<Expr> out_shapes;
    for (size_t i = 0; i < res_tuple_node->fields.size(); ++i) {
      const auto* tensor_type_node = res_tuple_node->fields[i].as<TensorTypeNode>();
      ICHECK(tensor_type_node);
      // Put the shape func on the host. This also ensures that everything between
      // shape_of and shape_func is similarly on the host.
      Var alloc = MakeStaticAllocation(scope, GetRef<TensorType>(tensor_type_node),
                                       host_virtual_device_, "out_shape_" + std::to_string(i));
      out_shapes.push_back(alloc);
    }

    // Represent the call in DPS form.
    auto shape_call = InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes),
                                  Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
    Var shape_func_var("shape_func", Type(nullptr));
    scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_virtual_device_));
    return out_shapes;
  }

  // Generate the code for invoking the TVM primitive \p func whose results have dynamic shapes.
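  // Sketch for a single output (names are illustrative): this emits, roughly,
  //   <EmitShapeFunc bindings, producing %out_shape_0 on the host>
  //   let %storage_0 = memory.alloc_storage(<size computed from %out_shape_0>, 64, ...);
  //   let %out_0 = memory.alloc_tensor(%storage_0, 0, %out_shape_0, ...);
  //   let %_ = vm.invoke_tvm_op(func, ins, (%out_0));
  //   %out_0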
  Expr DynamicInvoke(LetList* scope, const Expr& func, const Tuple& ins,
                     const CallLoweredAttrs& attrs, const std::vector<TensorType>& out_types,
                     const Type& ret_type, const VirtualDevice& virtual_device) {
    Array<Expr> out_shapes = EmitShapeFunc(scope, ins, attrs);
    std::vector<Var> storages;
    CHECK_EQ(out_shapes.size(), out_types.size());
    for (size_t i = 0; i < out_shapes.size(); ++i) {
      auto out_shape = out_shapes[i];
      auto out_type = out_types[i];
      auto size =
          MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_virtual_device_);
      // Alignment is directly captured in the instruction so don't wrap in "on_device".
      auto alignment = ComputeAlignment(out_type->dtype);
      Var sto_var("storage_" + std::to_string(i), Type(nullptr));
      auto val = AllocStorage(size, alignment, virtual_device, out_type->dtype);
      storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, virtual_device)));
    }

    Array<Expr> outs;
    for (size_t i = 0; i < storages.size(); ++i) {
      auto out_shape = out_shapes[i];
      auto out_type = out_types[i];
      auto storage = storages[i];
      auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape);
      Var out_var("out_" + std::to_string(i), Type(nullptr));
      outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, virtual_device)));
    }

    Tuple tuple_outs(outs);
    auto call =
        InvokeTVMOp(func, ins, tuple_outs, Downcast<DictAttrs>(attrs.metadata.at("relay_attrs")));
    scope->Push(MaybeOnDeviceFixed(call, virtual_device));
    return ToTupleType(ret_type,
                       std::vector<Expr>(tuple_outs->fields.begin(), tuple_outs->fields.end()));
  }

  Expr EmitReshapeTensor(LetList* scope, const Tuple& ins, const CallLoweredAttrs& attrs,
                         const Type& ret_type) {
    ICHECK_GE(ins->fields.size(), 1);  // static reshape
    ICHECK_LE(ins->fields.size(), 2);  // dynamic reshape, second arg is shape
    TensorType ret_ty = Downcast<TensorType>(ret_type);
    Expr shape_expr;
    if (IsDynamic(ret_type)) {
      // Even though the desired output shape has been passed as the second argument to
      // the dyn.reshape primitive, we'll still call that primitive's shape function. Go figure.
      Array<Expr> out_shapes = EmitShapeFunc(scope, ins, attrs);
      ICHECK_EQ(out_shapes.size(), 1);
      shape_expr = out_shapes[0];
    } else {
      std::vector<int64_t> shape;
      for (const auto& it : ret_ty->shape) {
        const auto* imm = it.as<IntImmNode>();
        CHECK(imm) << "expect static int shape";
        shape.push_back(imm->value);
      }
      shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_virtual_device_);
    }
    return ReshapeTensor(ins->fields[0], shape_expr, ret_ty->shape);
  }

 private:
  const Op& device_copy_op_ = Op::Get("device_copy");
  runtime::DataType compute_dtype_ = runtime::DataType::Int(64);
  IRModule mod_;
  VirtualDevice host_virtual_device_;

  std::vector<LetList> scopes_;
};

namespace transform {

Pass ManifestAllocImportStorage() {
  auto pass_func = [](IRModule mod, tvm::transform::PassContext pass_cnxt) {
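    // core.rly from Relay's standard library declares, among other things, the Storage
    // type the memory dialect relies on.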
    mod.CopyOnWrite();
    mod->ImportFromStd("core.rly");
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "ManifestAllocImportStorage",
                                          /*required=*/{});
}

Pass ManifestAllocImpl(VirtualDevice host_virtual_device) {
  auto pass_func = [host_virtual_device](Function func, IRModule mod, PassContext ctxt) {
    return DialectRewriter(mod, host_virtual_device).Rewrite(func);
  };
  return CreateFunctionPass(pass_func, 0, "ManifestAllocImpl", {});
}

Pass ManifestAlloc(VirtualDevice cpu_virtual_device) {
  std::vector<Pass> passes = {ManifestAllocImportStorage(), InferType(),
                              ManifestAllocImpl(std::move(cpu_virtual_device)), InferType()};
  return Sequential(passes, "ManifestAlloc");
}

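// A minimal usage sketch (assuming `mod` is an IRModule in A-normal form and `host_vd`
// describes the host device):
//   IRModule mod = ...;
//   mod = ManifestAlloc(host_vd)(mod);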
TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc").set_body_typed(ManifestAlloc);

}  // namespace transform

}  // namespace relay
}  // namespace tvm