#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>

#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/Functions.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/native/CPUFallback.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/library.h>
#include <sstream>
#include <unordered_map>

namespace torch {
namespace lazy {
namespace {
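
// Copy a list of tensors to the given eager device. The CPU case uses the
// fused at::_to_cpu() path; any other eager device type falls back to a
// per-tensor Tensor::to() copy.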
std::vector<at::Tensor> _to_eager(
    at::TensorList tensors,
    c10::DeviceType device_type) {
  switch (device_type) {
    case at::kCPU: {
      return at::_to_cpu(tensors);
    }
    default: {
      std::vector<at::Tensor> eager_tensors;
      for (const auto& t : tensors) {
        c10::TensorOptions options = t.options().device(device_type);
        at::Tensor eager_tensor = t.to(
            options,
            /*non_blocking*/ false,
            /*copy*/ false);
        eager_tensors.push_back(eager_tensor);
      }
      return eager_tensors;
    }
  }
}

// Convenience helpers for converting tensor arguments to the eager device,
// skipping over undefined (or absent optional) tensors that _to_eager()
// cannot handle.
std::vector<at::Tensor> to_eager(
    const at::TensorList& tensors,
    c10::DeviceType device_type) {
  // We can't just call _to_eager() on the entire list of Tensors because it
  // will break on undefined tensors. Separate out undefined tensors first.
  std::vector<at::Tensor> eager_tensors(tensors.size());
  std::vector<at::Tensor> valid_tensors;
  std::vector<bool> to_translate(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    const at::Tensor& tensor = tensors[i];
    // Explicitly handling undefined tensors here instead of letting
    // `_to_eager` handle it. Otherwise, we'd need to require all backends
    // with their own implementation of _to_eager to properly handle undefined
    // tensors.
    if (tensor.defined()) {
      to_translate[i] = true;
      valid_tensors.push_back(tensor);
    } else {
      eager_tensors[i] = tensor;
    }
  }
  auto eager_valid_tensors = _to_eager(valid_tensors, device_type);
  for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
    if (to_translate[i]) {
      eager_tensors[i] = std::move(eager_valid_tensors[defined_pos++]);
    }
  }
  return eager_tensors;
}

std::vector<c10::optional<at::Tensor>> to_eager(
    const std::vector<c10::optional<at::Tensor>>& tensors,
    c10::DeviceType device_type) {
  // We can't just call _to_eager() on the entire list of Tensors because it
  // will break on undefined tensors. Separate out undefined tensors first.
  std::vector<c10::optional<at::Tensor>> eager_tensors(tensors.size());
  std::vector<at::Tensor> valid_tensors;
  std::vector<bool> to_translate(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    const c10::optional<at::Tensor>& tensor = tensors[i];
    // Explicitly handling undefined tensors here instead of letting
    // `_to_eager` handle it. Otherwise, we'd need to require all backends
    // with their own implementation of _to_eager to properly handle undefined
    // tensors.
    if (tensor.has_value() && tensor->defined()) {
      to_translate[i] = true;
      valid_tensors.push_back(*tensor);
    } else {
      eager_tensors[i] = tensor;
    }
  }
  auto eager_valid_tensors = _to_eager(valid_tensors, device_type);
  for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
    if (to_translate[i]) {
      eager_tensors[i] = std::move(eager_valid_tensors[defined_pos++]);
    }
  }
  return eager_tensors;
}

c10::DispatchKey dispatch_key(c10::DeviceType device_type) {
  switch (device_type) {
    case at::kCPU: {
      return c10::DispatchKey::CPU;
    }
    case at::kCUDA: {
      return c10::DispatchKey::CUDA;
    }
    default: {
      AT_ERROR("Unsupported device type: ", device_type);
    }
  }
}
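
// Pick the device that eager outputs should be copied back to, based on the
// op's tensor (and tensor-list) arguments. Returns nullopt if the op has no
// tensor arguments at all.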
c10::optional<c10::Device> compute_target_device(
    std::vector<at::Tensor>& t_args,
    std::vector<c10::List<at::Tensor>> tlist_args,
    std::vector<c10::List<c10::optional<at::Tensor>>> opt_tlist_args) {
  // Decide what device to move the output tensor(s) to.
  // The current convention is that we use the first tensor argument to pick
  // the device. Barring that, we take the first tensor from a TensorList
  // argument.
  if (!t_args.empty()) {
    return t_args[0].device();
  } else {
    // We need to loop through all of the (potentially multiple) TensorList
    // arguments in case, e.g., the first one is empty but the second is not.
    for (auto& tens_list : tlist_args) {
      for (const auto i : c10::irange(tens_list.size())) {
        return tens_list.get(i).device();
      }
    }
    for (auto& tens_list : opt_tlist_args) {
      for (const auto i : c10::irange(tens_list.size())) {
        if (tens_list.get(i).has_value()) {
          return tens_list.get(i)->device();
        }
      }
    }
  }
  return c10::nullopt;
}

} // namespace

static std::unordered_map<std::string, ::torch::lazy::Counter*>
    _eager_fallback_counters;
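
// Returns true if `op` must always take the eager fallback path: either the
// user forced it via the LTC force-fallback setting (matched against the
// op's qualified name, e.g. "aten::nonzero"), or lazy shape inference for the
// op is known to be unreliable.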
bool force_eager_fallback(c10::Symbol op) {
  auto force_str = getLTCForceFallback();
  if (!force_str.empty()) {
    static auto force_sym = c10::Symbol::fromQualString(std::string(force_str));
    if (op == force_sym) {
      return true;
    }
  }
  if (op == at::aten::nonzero) {
    // When symbolic shape mode is not enabled, the nonzero shape function
    // returns an incorrect result.
    return !symbolicShapeEnabled();
  }

  return false;
}
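
// Boxed fallback kernel registered for the Lazy dispatch key. It records a
// per-operator fallback counter and then runs the op eagerly via
// ts_eager_fallback() on the backend's EagerFallbackDeviceType().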
void ltc_eager_fallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  // TODO(whc): FN_TRACK hasn't been used in LTC so far, but we could
  // re-enable it here: LTC_FN_TRACK(3);
  const auto name = c10::toString(op.operator_name());

  // Manually applying the TORCH_LAZY_COUNTER macro.
  // We need to do it ourselves and explicitly keep a mapping of counters
  // because this boxed fallback kernel is used by multiple operators,
  // and the macro stamps out a static Counter object with a fixed name
  // at the code location where it is called.
  if (_eager_fallback_counters.find(name) == _eager_fallback_counters.end()) {
    _eager_fallback_counters[name] = new ::torch::lazy::Counter(name);
  }
  _eager_fallback_counters[name]->AddValue(1);

  auto& args = op.schema().arguments();
  auto arguments = torch::jit::last(stack, args.size());

  // Log each tensor argument.
  for (const auto& ivalue : arguments) {
    if (ivalue.isTensor()) {
      VLOG(3) << ivalue.toTensor().toString();
    }
  }

  // Call the actual boxed eager fallback.
  ts_eager_fallback(
      op, stack, torch::lazy::getBackend()->EagerFallbackDeviceType());
}
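
// Registers ltc_eager_fallback as the boxed fallback for the Lazy dispatch
// key. Called explicitly when the TorchScript lazy backend is initialized,
// rather than at static-init time (see the note below).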
void register_ts_ltc_eager_fallback() {
  static auto m = MAKE_TORCH_LIBRARY_IMPL(_, Lazy);
  // Most backends use TORCH_LIBRARY_* macros which perform their dispatcher
  // registrations at static library init time, but the lazy TorchScript
  // backend does not, since it is built into the main torch lib but not
  // always used. In particular, if another external backend wants to register
  // itself to the same key (Lazy), the TorchScript backend must not be
  // initialized.
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&ltc_eager_fallback>());
}
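
// Generic boxed eager fallback: copies lazy tensor inputs to `device_type`,
// redispatches the op to the corresponding eager kernel, writes the results
// of mutable-alias arguments back to the original inputs, and moves any
// returned tensors back to the original device.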
void ts_eager_fallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    c10::DeviceType device_type) {
  auto& schema_args = op.schema().arguments();
  const auto num_arguments = schema_args.size();
  auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  std::vector<at::Tensor> tensor_args;
  std::vector<int> tensor_args_indices;

  std::vector<c10::List<at::Tensor>> tensorlist_args;
  std::vector<c10::List<c10::optional<at::Tensor>>> opt_tensorlist_args;

  // Step 1: Convert all non-eager tensor inputs into eager tensors and put
  // them on the stack at the correct indices.
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      tensor_args.push_back(ivalue.toTensor());
      tensor_args_indices.push_back(idx);
    } else if (ivalue.isTensorList()) {
      // Note: we copy each TensorList argument to eager individually out of
      // convenience, but XLA would benefit from materializing all tensor and
      // TensorList args onto the CPU at the same time. We can improve this if
      // we need better perf for XLA's CPU fallbacks.
      auto eager_ivalue = c10::IValue(c10::List<at::Tensor>(
          to_eager(ivalue.toTensorVector(), device_type)));
      (*stack)[arguments_begin + idx] = std::move(eager_ivalue);
      tensorlist_args.push_back(ivalue.toTensorList());
    } else if (ivalue.isOptionalTensorList()) {
      auto eager_ivalue = c10::IValue(c10::List<c10::optional<at::Tensor>>(
          to_eager(ivalue.toOptionalTensorVector(), device_type)));
      (*stack)[arguments_begin + idx] = std::move(eager_ivalue);
      opt_tensorlist_args.push_back(ivalue.toOptionalTensorList());
    }
  }
  // XLA requires all of the tensor arguments to be gathered up and converted
  // to CPU together.
  auto eager_tensors = to_eager(tensor_args, device_type);

  for (const auto i : c10::irange(tensor_args_indices.size())) {
    auto idx = tensor_args_indices[i];
    (*stack)[arguments_begin + idx] = c10::IValue(eager_tensors[i]);
  }

  // Step 2: Call the underlying eager implementation of the operator.
  op.redispatchBoxed(c10::DispatchKeySet(dispatch_key(device_type)), stack);

  // Step 3: We need to take special care to handle mutable aliases properly:
  // If any input tensors are mutable aliases, we need to directly copy the
  // updated data on the eager tensors back to the original inputs.
  for (const auto i : c10::irange(tensor_args_indices.size())) {
    auto tensor_idx = tensor_args_indices[i];
    const auto alias_info = schema_args[tensor_idx].alias_info();
    if (alias_info != nullptr && alias_info->isWrite()) {
      at::_copy_from_and_resize(eager_tensors[i], tensor_args[i]);
    }
  }

  // Step 4: Convert any eager output tensors back to the original input
  // device. For mutable alias'd outputs, we also need to take special care
  // to move the ORIGINAL input tensor back onto the stack, in place of the
  // temporary eager output tensor that we created.
  //
  // Note [Eager Fallback Does Not Handle View Operators]
  // Also note that we are incapable of handling immutable aliases properly.
  // Why?
  // Schemas with an immutable alias'd tensor output correspond to view
  // operators. For example, the `view_as` schema from native_functions.yaml:
  // `view_as(Tensor(a) self, Tensor other) -> Tensor(a)`
  // We can't handle these ops properly, because view ops are supposed to
  // return a NEW tensor that shares the SAME storage as the original tensor.
  // However, the new tensor that we created cannot share the same storage,
  // since it lives on the eager CPU / CUDA device while the original tensor
  // lives on a different device. Because of that, we error out if someone
  // attempts to call the eager fallback on a view operator (the generic CPU
  // fallback warns instead, to maintain BC for XLA view ops that fall back
  // to CPU).
  const auto& schema_returns = op.schema().returns();
  const auto& num_returns = schema_returns.size();
  auto returns = torch::jit::last(stack, num_returns);
  const auto returns_begin = stack->size() - num_returns;

  for (const auto idx : c10::irange(returns.size())) {
    if (returns[idx].isTensor()) {
      const auto& return_tens = returns[idx].toTensor();
      if (return_tens.defined()) {
        const auto alias_info = schema_returns[idx].alias_info();
        if (alias_info != nullptr && alias_info->isWrite()) {
          // Case (1): mutable alias case. Move the input ivalue directly onto
          // the stack in place of the existing eager output tensor.
          bool found_alias = false;
          // We could store some extra metadata on the function schema to
          // avoid the loop here if we need to improve perf.
          for (const auto i : c10::irange(tensor_args_indices.size())) {
            auto input_tensor_idx = tensor_args_indices[i];
            const auto& input_tensor = eager_tensors[i];
            const auto input_alias_info =
                schema_args[input_tensor_idx].alias_info();
            if (input_tensor.defined() && input_alias_info != nullptr &&
                *alias_info == *input_alias_info) {
              // We've found the original input tensor that aliases with the
              // current output. Wrap it in an IValue and put it directly on
              // the stack.
              (*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]);
              found_alias = true;
              break;
            }
          }
          TORCH_CHECK(
              found_alias,
              "The operator ",
              op.schema().operator_name(),
              " appears to have invalid alias information. ",
              "Found a return tensor argument with a mismatched "
              "mutable alias: ",
              schema_returns[idx]);
        } else {
          c10::optional<c10::Device> tgt_device = compute_target_device(
              tensor_args, tensorlist_args, opt_tensorlist_args);
          if (alias_info != nullptr && !alias_info->isWrite()) {
            // Immutable alias (view) case: error out, since copying the data
            // would not create a real view of the original tensor. If this
            // operator is needed, the backend should provide a kernel for it.
            // See Note [Eager Fallback Does Not Handle View Operators]
            std::stringstream dev_str;
            if (tgt_device) {
              dev_str << *tgt_device;
            } else {
              dev_str << "<none>";
            }
            // We should never hit this for a view op: LazyTensor should
            // provide a lowering for the corresponding view_copy operator,
            // and the functionalization pass takes care of calling the
            // view_copy operator instead of the view.
            TORCH_CHECK(
                false,
                "The operator ",
                op.schema().operator_name(),
                " appears to be a view operator, ",
                "but it has no implementation for the backend \"",
                dev_str.str(),
                "\". View operators don't support ",
                "falling back to run eagerly, since the tensor's "
                "storage cannot be shared across devices.");
          }
          // Case (2): copy case. Copy the eager output tensor back to the
          // original device.

          // We technically might not have a target device, e.g. if you call
          // torch.cat() with an empty list. In that case, there are no
          // tensors to move across devices anyway.
          if (tgt_device) {
            (*stack)[returns_begin + idx] =
                c10::IValue(returns[idx].toTensor().to(*tgt_device));
          }
        }
      }
    }
  }
}

} // namespace lazy
} // namespace torch