1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/memory_types.h" |
17 | |
18 | #include <utility> |
19 | |
20 | #include "tensorflow/compiler/jit/defs.h" |
21 | #include "tensorflow/core/framework/attr_value.pb.h" |
22 | #include "tensorflow/core/framework/kernel_def.pb.h" |
23 | #include "tensorflow/core/framework/node_def.pb.h" |
24 | #include "tensorflow/core/framework/node_def_util.h" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/types.h" |
27 | #include "tensorflow/core/lib/core/errors.h" |
28 | #include "tensorflow/core/platform/types.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | namespace { |
33 | // Returns the largest endpoint of anything in the name_map. |
34 | int GetTotal(const NameRangeMap& name_map) { |
35 | int total = 0; |
36 | for (const auto& item : name_map) { |
37 | total = std::max(total, item.second.second); |
38 | } |
39 | return total; |
40 | } |
41 | |
42 | // Fills memory_types for either input or output, setting everything |
43 | // to DEVICE_MEMORY except those args in host_memory_args. Removes |
44 | // elements of host_memory_args that were used. |
45 | void MemoryTypesHelper(const NameRangeMap& name_map, |
46 | std::vector<string>* host_memory_args, |
47 | MemoryTypeVector* memory_types) { |
48 | // Update args that have been marked as in "HOST_MEMORY". |
49 | size_t keep = 0; |
50 | for (size_t i = 0; i < host_memory_args->size(); ++i) { |
51 | auto iter = name_map.find((*host_memory_args)[i]); |
52 | if (iter != name_map.end()) { |
53 | for (int j = iter->second.first; j < iter->second.second; ++j) { |
54 | (*memory_types)[j] = HOST_MEMORY; |
55 | } |
56 | } else { |
57 | // (*host_memory_args)[i] not found, save it for the next pass. |
58 | if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i]; |
59 | ++keep; |
60 | } |
61 | } |
62 | host_memory_args->resize(keep); |
63 | } |
64 | |
// Returns true iff `op_type` exactly matches one of the op names listed in
// the table below (comparison is case-sensitive).
bool IsFunctionCallOp(const string& op_type) {
  static constexpr const char* kFunctionCallOps[] = {
      "SymbolicGradient",        "PartitionedCall", "StatefulPartitionedCall",
      "While",                   "StatelessWhile",
  };
  for (const char* candidate : kFunctionCallOps) {
    if (op_type == candidate) {
      return true;
    }
  }
  return false;
}
70 | |
71 | } // namespace |
72 | |
73 | MemoryType MTypeFromDType(const DataType dtype) { |
74 | return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY |
75 | : DEVICE_MEMORY; |
76 | } |
77 | |
78 | MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype) { |
79 | return DataTypeAlwaysOnHost(dtype) ? HOST_MEMORY : DEVICE_MEMORY; |
80 | } |
81 | |
82 | Status MemoryTypesForNode(const OpRegistryInterface* op_registry, |
83 | const DeviceType& device_type, const NodeDef& ndef, |
84 | MemoryTypeVector* inp_mtypes, |
85 | MemoryTypeVector* out_mtypes) { |
86 | // Look up the Op registered for this op name. |
87 | const OpDef* op_def; |
88 | TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(ndef.op(), &op_def)); |
89 | |
90 | // Look up the Kernel registered for this node def. |
91 | const KernelDef* kdef = nullptr; |
92 | Status status = |
93 | FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */); |
94 | |
95 | DataTypeVector inp_dtypes; |
96 | DataTypeVector out_dtypes; |
97 | TF_RETURN_IF_ERROR( |
98 | InOutTypesForNode(ndef, *op_def, &inp_dtypes, &out_dtypes)); |
99 | |
100 | inp_mtypes->clear(); |
101 | out_mtypes->clear(); |
102 | |
103 | bool has_xla_compile = [&] { |
104 | const auto& it = ndef.attr().find(kXlaMustCompileAttr); |
105 | return it != ndef.attr().end() && it->second.b(); |
106 | }(); |
107 | |
108 | bool has_kernel_def = status.ok() && !IsFunctionCallOp(ndef.op()); |
109 | auto host_memory_required = [&](const DataType& dt) { |
110 | bool int32_on_device = |
111 | has_kernel_def || device_type.type_string() == "TPU" || has_xla_compile; |
112 | return DataTypeAlwaysOnHost(dt) || (dt == DT_INT32 && !int32_on_device); |
113 | }; |
114 | |
115 | if (has_kernel_def) { |
116 | // Gets the input/output names and their corresponding endpoint ranges. |
117 | NameRangeMap inp_names; |
118 | NameRangeMap out_names; |
119 | TF_RETURN_IF_ERROR( |
120 | NameRangesForNode(ndef, *op_def, &inp_names, &out_names)); |
121 | |
122 | // Now that we know the size, fill with the default 'DEVICE_MEMORY'. |
123 | inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY); |
124 | out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); |
125 | |
126 | // Fills in host memory types based on the kernel def. |
127 | const auto& from_proto = kdef->host_memory_arg(); |
128 | std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); |
129 | MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); |
130 | MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); |
131 | if (!host_memory_args.empty()) { |
132 | return errors::InvalidArgument( |
133 | "HostMemory args '" , absl::StrJoin(host_memory_args, "', '" ), |
134 | "' not found in OpDef: " , SummarizeOpDef(*op_def)); |
135 | } |
136 | } else { |
137 | // Set all the datatype to DEVICE_MEMORY by default, later on change it to |
138 | // HOST_MEMORY where it is required by the datatype. |
139 | inp_mtypes->resize(inp_dtypes.size(), DEVICE_MEMORY); |
140 | out_mtypes->resize(out_dtypes.size(), DEVICE_MEMORY); |
141 | } |
142 | CHECK_LE(inp_mtypes->size(), inp_dtypes.size()); |
143 | CHECK_LE(out_mtypes->size(), out_dtypes.size()); |
144 | |
145 | // Mark e.g. all resource and string types as host memory. |
146 | for (int i = 0; i < inp_mtypes->size(); ++i) { |
147 | if (host_memory_required(inp_dtypes[i])) { |
148 | (*inp_mtypes)[i] = HOST_MEMORY; |
149 | } |
150 | } |
151 | for (int i = 0; i < out_mtypes->size(); ++i) { |
152 | if (host_memory_required(out_dtypes[i])) { |
153 | (*out_mtypes)[i] = HOST_MEMORY; |
154 | } |
155 | } |
156 | |
157 | std::vector<int32> hostmem_attr; |
158 | if (TryGetNodeAttr(ndef, "_input_hostmem" , &hostmem_attr)) { |
159 | for (int32_t i : hostmem_attr) { |
160 | if (0 <= i && i < inp_mtypes->size()) { |
161 | (*inp_mtypes)[i] = HOST_MEMORY; |
162 | } |
163 | } |
164 | } |
165 | hostmem_attr.clear(); |
166 | if (TryGetNodeAttr(ndef, "_output_hostmem" , &hostmem_attr)) { |
167 | for (int32_t i : hostmem_attr) { |
168 | if (0 <= i && i < out_mtypes->size()) { |
169 | (*out_mtypes)[i] = HOST_MEMORY; |
170 | } |
171 | } |
172 | } |
173 | |
174 | return OkStatus(); |
175 | } |
176 | |
177 | } // namespace tensorflow |
178 | |