1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * \file tir/analysis/device_constraint_utils.h
22 | * \brief Utilities for extracting and applying device-related constraints to \p PrimFunc |
23 | * parameters. |
24 | * |
 * These utilities are used by the \p PlanDevices pass to extract memory (aka 'storage') scope
 * information from \p PrimFuncs and convert it back into \p VirtualDevice form w.r.t. the
 * original Relay type of the \p PrimFunc (i.e. before flattening of tuple arguments/results and
 * conversion to destination-passing style aka DPS).
29 | * |
 * A utility is also supplied to go the other way: impose memory scopes on \p PrimFunc parameters.
 * However, that utility is still EXPERIMENTAL.
32 | * |
33 | * We may extend these utilities to also gather/apply layout information should we add that to |
34 | * \p VirtualDevice. |
35 | */ |
36 | |
37 | #ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ |
38 | #define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ |
39 | |
40 | #include <tvm/target/virtual_device.h> |
41 | #include <tvm/tir/function.h> |
42 | |
43 | namespace tvm { |
44 | namespace tir { |
45 | |
46 | /*! |
47 | * A Relay Function with type: |
48 | * \code |
 * fn((Tensor[...], Tensor[...]), Tensor[...]) -> (Tensor[...], Tensor[...])
 *     ^            ^             ^                ^            ^
 *     a            b             c                d            e
 * \endcode
 * will be represented by a TIR PrimFunc in flattened and DPS form with at least 5 arguments a..e.
54 | * \code |
55 | * primfn(a: handle, b: handle, c: handle, d: handle, e: handle) { |
56 | * buffers = { ... } |
57 | * buffer_map = { ... } |
58 | * ... |
59 | * } |
60 | * \endcode |
61 | * |
 * Each such PrimFunc argument will be mapped to a \p Buffer whose underlying \p data \p Var
 * has a \p PointerType.
64 | * |
 * The PrimFunc may have additional non-pointer arguments, e.g. for:
 *  - scalar inputs and tensor dimensions
 *  - device contexts
 * Those should be ignored here since they have no counterpart in the Relay Function.
69 | * |
70 | * We'll need helpers to map on-the-fly between the Relay and TIR view of functions. |
71 | */ |
72 | |
73 | /*! |
74 | * \brief Returns the \p VirtualDevices capturing the memory (aka storage) scope constraints for all |
75 | * the arguments and result of \p prim_func. However the result will be w.r.t. the \p prim_func's |
76 | * representation as a Relay \p Function of \p relay_func_type_ before lowering and conversion to |
77 | * DPS. |
78 | */ |
79 | Array<VirtualDevice> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, |
80 | const FuncType& relay_func_type); |
81 | |
82 | /* |
83 | * \brief Returns \p prim_func written to capture the memory (aka storage) scope constraints |
84 | * for each of the \p prim_func's parameters given by \p arg_and_result_virtual_devices. However, |
85 | * \p arg_and_result_virtual_devices should be w.r.t. the \p prim_func's representation as a Relay |
86 | * \p Function of \p relay_func_type before lowering and conversion to DPS. |
87 | * |
88 | * CAUTION: This is experimental. The resulting \p PrimFunc may not have fully accounted for all |
89 | * new memory scopes. |
90 | */ |
91 | PrimFunc ApplyPrimFuncArgAndResultConstraints( |
92 | const PrimFunc& prim_func, const FuncType& relay_func_type, |
93 | const Array<VirtualDevice>& arg_and_result_virtual_devices); |
94 | |
95 | } // namespace tir |
96 | } // namespace tvm |
97 | |
98 | #endif // TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ |
99 | |