#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/Functions.h>
#include <ATen/MetaFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
#include <torch/csrc/lazy/ts_backend/config.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
#include <torch/csrc/lazy/ts_backend/tensor_aten_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h>
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
#include <torch/library.h>

using at::Tensor;

namespace torch {
namespace lazy {
namespace {

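// Wraps an eager at::Tensor in a LazyTensor on the given backend device and
// returns it as an at::Tensor. If the tensor is undefined or no device is
// provided, the input is returned unchanged.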
at::Tensor CreateLtcTensor(
    const at::Tensor& tensor,
    const c10::optional<torch::lazy::BackendDevice>& device) {
  if (tensor.defined() && device) {
    return torch::lazy::CreateAtenFromLtcTensor(
        torch::lazy::LazyTensor::Create(tensor, *device));
  }
  return tensor;
}

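// Maps an optional ATen device to the corresponding lazy BackendDevice.
// Returns nullopt if no device is given or the device is not a lazy device.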
c10::optional<torch::lazy::BackendDevice> GetLtcDevice(
    const c10::optional<c10::Device>& device) {
  if (!device) {
    return c10::nullopt;
  }
  if (device->type() != at::kLazy) {
    return c10::nullopt;
  }
  return torch::lazy::atenDeviceToBackendDevice(*device);
}

} // namespace

// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor LazyNativeFunctions::clone(
    const at::Tensor& self,
    c10::optional<at::MemoryFormat> memory_format) {
  auto self_lt = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
}

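// Handles copies between eager and lazy tensors, and between two lazy
// tensors; which branch runs depends on which of `self`/`dst` is lazy.
// Typically reached from `dst.copy_(self)` when at least one side lives on a
// lazy device.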
at::Tensor LazyNativeFunctions::_copy_from(
    const at::Tensor& self,
    const at::Tensor& dst,
    bool non_blocking) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    // Providing a new 'eager' value (self) for an existing lazy tensor (dst).
    static bool sync_update = FLAGS_torch_lazy_ts_tensor_update_sync;
    CHECK(dst_tensor);
    dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
  } else if (!dst_tensor) {
    // Materializing a lazy tensor (self) and copying its value into the eager
    // tensor (dst). detached=false lets us skip a copy in `ToTensor`, which
    // should be safe because we are only going to use the tensor for
    // dst.copy_().
    CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/false);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // Copying one lazy tensor to another.
    if (!dst_tensor->CurrentIrValue()) {
      // If dst is not backed by IR (which it would be if it were, e.g., the
      // result of some lazy operation), then it should have at::Tensor data
      // backing it instead.
      auto dst_tensor_data = dst_tensor->CurrentTensorData();
      CHECK(dst_tensor_data);
      auto src_tensor_data = self_tensor->CurrentTensorData();
      if (src_tensor_data) {
        // Both src and dst are simply backed by at::Tensor data (no IR), so
        // do a straightforward copy.
        dst_tensor_data->copy_(*src_tensor_data);
      } else {
        // src needs to be materialized before its result can be used for a
        // copy into dst. Since we use the src tensor only for making a copy,
        // we don't need to detach it.
        // Note: it would be even more efficient if we could cause ToTensor to
        // materialize the value directly into dst's buffer (that would need
        // to be detached though).
        dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false));
      }
    } else {
      copy_(dst_tensor, self_tensor);
      auto* impl =
          dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
      impl->set_tensor(dst_tensor);
    }
  }
  return dst;
}

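// Like _copy_from, but also refreshes the destination's size metadata so that
// `dst` ends up with the shape of `self`.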
at::Tensor LazyNativeFunctions::_copy_from_and_resize(
    const at::Tensor& self,
    const at::Tensor& dst) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  if (!self_tensor) {
    CHECK(dst_tensor);
    dst_tensor->UpdateFromTensorOut(self);
  } else if (!dst_tensor) {
    CHECK(self_tensor);
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
    at::Tensor typed_tensor =
        torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // At this point we know dst is a lazy tensor.
    auto* dest_impl =
        dynamic_cast<torch::lazy::LTCTensorImpl*>(dst.unsafeGetTensorImpl());
    dest_impl->tensor()->UpdateFromTensorOut(self_tensor);
    dest_impl->force_refresh_sizes();
  }
  return dst;
}

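// Implements to()/_to_copy() for lazy tensors. Four cases are handled below:
// eager->lazy, lazy->eager, lazy->lazy across device indices, and lazy->lazy
// on the same device (kept inside the lazy graph). Roughly speaking,
// `cpu_tensor.to(at::Device(at::kLazy))` lands in case 1, while
// `lazy_tensor.cpu()` lands in case 2.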
at::Tensor LazyNativeFunctions::_to_copy(
    const at::Tensor& self,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    bool non_blocking,
    c10::optional<at::MemoryFormat> memory_format) {
  if (force_eager_fallback(at::aten::_to_copy)) {
    TORCH_INTERNAL_ASSERT(
        false,
        "Fallback is currently impossible for _to_copy since the fallback helper itself reinvokes _to_copy");
  }

  auto options = self.options();
  if (dtype) {
    // I put each of these setters in a conditional instead of doing
    // `self.options().dtype(dtype).layout(layout)...` because calling
    // .dtype(nullopt) on an options() that already has dtype appears to wipe
    // it.
    options = options.dtype(dtype);
  }
  if (layout) {
    options = options.layout(layout);
  }
  if (memory_format) {
    options = options.memory_format(memory_format);
  }
  if (pin_memory) {
    // TODO(whc) can we honor 'pin_memory' in some/all cases?
    options = options.pinned_memory(pin_memory);
    TORCH_WARN_ONCE(
        "Pinned memory used in lazy _to_copy, check if the behavior is as intended");
  }

  TORCH_LAZY_FN_COUNTER("lazy::");
  auto lazy_self = torch::lazy::TryGetLtcTensor(self);
  if (!lazy_self && device && device->type() == c10::kLazy) {
    // Case 1: eager->lazy (we create a new lazy tensor)
    // See Note [Lazy Tensor Functionalization]
    // Invariant: if the functionalization key is in the exclude set, then we're
    // expected to return an ordinary tensor, which will be "lifted" into a
    // functional wrapper later.
    bool functionalize_output =
        !c10::impl::tls_local_dispatch_key_set().excluded_.has(
            c10::DispatchKey::Functionalize);
    return torch::lazy::to_lazy_tensor(
        self,
        options,
        *device,
        /*non_blocking=*/non_blocking,
        /*functionalize_output=*/functionalize_output);
  } else if (device && device->type() != c10::kLazy) {
    // Case 2: lazy->eager (forces a graph break since we are materializing a
    // tensor)

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    options = options.device(device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
    return moved_eager_tensor;
  } else if (
      device && device->type() == c10::kLazy && device->has_index() &&
      device->index() != self.device().index()) {
    // Case 3: lazy:0 -> lazy:1

    // TODO(whc) what do we actually want to do here?
    // option 1: materialize, move eager tensor, create new lazy tensor
    //   - this should be our default, as it is what would happen before we
    //     implemented _to_copy
    //   - actually combines case 1 + case 2
    // option 2: support multiple devices inside one lazy/TS executor (case 4)
    //   - but: we may have other assumptions that there is just one device
    //     per executor? so don't take this lightly

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    // We move the eager tensor to the 'eager' equivalent of our lazy device,
    // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is
    // what we use.
    auto eager_device = c10::Device(
        torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index());
    options = options.device(eager_device);
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true);
    lazy_self = torch::lazy::GetOrCreateLtcTensor(
        moved_eager_tensor,
        torch::lazy::atenDeviceToBackendDevice(eager_device));
    return torch::lazy::CreateAtenFromLtcTensor(lazy_self);

  } else {
    // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy
    // graph)

    // Note: the captured _to_copy will be executed with real eager tensors,
    // not lazy tensors. We DO NOT want to burn 'lazy:0' as the device into
    // this captured IR, or we will try to convert an eager tensor back to a
    // lazy one inside the torchscript executor. lazy:0 -> lazy:1 is handled
    // in case 3, so we can safely drop the device argument here.
    device = c10::nullopt;

    torch::lazy::NodePtr node = torch::lazy::ReuseNode<ToCopy>(
        lazy_self->GetIrValue(),
        dtype,
        layout,
        device,
        pin_memory,
        non_blocking,
        memory_format);
    if (!node) {
      auto shapes = torch::lazy::compute_shape__to_copy(
          self, dtype, layout, device, pin_memory, non_blocking, memory_format);
      TORCH_INTERNAL_ASSERT(shapes.size() == 1);
      node = torch::lazy::MakeNode<ToCopy>(
          lazy_self->GetIrValue(),
          dtype,
          layout,
          device,
          pin_memory,
          non_blocking,
          memory_format,
          std::move(shapes));
      CacheNode(node);
    }

    auto result =
        torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
            std::move(node), lazy_self->GetDevice()));
    return result;
  }
}

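// aten::empty for the lazy backend: the storage is first allocated on the
// eager fallback device, then wrapped as a lazy tensor (and, when
// functionalization is enabled, in a FunctionalTensorWrapper).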
at::Tensor LazyNativeFunctions::empty_symint(
    at::SymIntArrayRef sym_size,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory,
    c10::optional<at::MemoryFormat> memory_format) {
  // TODO: support this directly
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  const auto device_type =
      torch::lazy::getBackend()->EagerFallbackDeviceType();
  at::TensorOptions options = at::TensorOptions()
                                  .device(c10::Device(device_type))
                                  .layout(layout)
                                  .pinned_memory(pin_memory)
                                  .dtype(dtype);
  auto x_result = at::empty(size, options, memory_format);
  auto tensor = CreateLtcTensor(x_result, GetLtcDevice(device));
  // See Note [Lazy Tensor Functionalization]
  if (c10::impl::tls_local_dispatch_key_set().excluded_.has(
          c10::DispatchKey::Functionalize)) {
    // Invariant: if the functionalization key is in the exclude set, then we're
    // expected to return an ordinary tensor, which will be "lifted" into a
    // functional wrapper later.
    return tensor;
  } else {
    auto wrapped = at::functionalization::impl::to_functional_tensor(tensor);
    return wrapped;
  }
}

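// empty_strided is implemented as empty followed by as_strided on the result.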
at::Tensor LazyNativeFunctions::empty_strided_symint(
    at::SymIntArrayRef sym_size,
    at::SymIntArrayRef sym_stride,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  at::Tensor t =
      empty_symint(sym_size, dtype, layout, device, pin_memory, c10::nullopt);
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride);
  return t.as_strided(size, stride, /*storage_offset=*/0);
}

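// In-place scalar fill, recorded on the lazy graph via torch::lazy::fill_.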
at::Tensor& LazyNativeFunctions::fill_(
    at::Tensor& self,
    const at::Scalar& value) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
  torch::lazy::fill_(self_tensor, value);
  return self;
}

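// max_pool3d forward is routed through a custom autograd function; see the
// note below about saved indices.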
at::Tensor LazyNativeFunctions::max_pool3d(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  return torch::lazy::MaxPool3dAutogradFunctionTS::apply(
      self, kernel_size, stride, padding, dilation, ceil_mode);
}

// We need to explicitly override max pooling operators and just call the
// fallback for them because we've customized the autograd function for them
// (backward needs saved indices from forward).
std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::max_pool3d_with_indices(
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode) {
  return at::native::
      call_fallback_fn<&ltc_eager_fallback, ATEN_OP(max_pool3d_with_indices)>::
          call(self, kernel_size, stride, padding, dilation, ceil_mode);
}

at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
    const at::Tensor& grad_output,
    const at::Tensor& self,
    at::IntArrayRef kernel_size,
    at::IntArrayRef stride,
    at::IntArrayRef padding,
    at::IntArrayRef dilation,
    bool ceil_mode,
    const at::Tensor& indices) {
  return at::native::call_fallback_fn<
      &ltc_eager_fallback,
      ATEN_OP(max_pool3d_with_indices_backward)>::
      call(
          grad_output,
          self,
          kernel_size,
          stride,
          padding,
          dilation,
          ceil_mode,
          indices);
}

at::Tensor& LazyNativeFunctions::normal_(
    at::Tensor& self,
    double mean,
    double std,
    c10::optional<at::Generator> generator) {
  // Unconditionally fall back: implementing normal_ via lazy tensors caused
  // differences in results compared to eager.
  return at::native::call_fallback_fn<&ltc_eager_fallback, ATEN_OP(normal_)>::
      call(self, mean, std, generator);

  // if (force_eager_fallback(c10::Symbol::fromQualString("aten::normal_"))) {
  //   return at::native::call_fallback_fn<&ltc_eager_fallback,
  //   ATEN_OP(normal_)>::call(self, mean, std, generator);
  // }

  // if (generator.has_value()) {
  //   return at::native::call_fallback_fn<&ltc_eager_fallback,
  //   ATEN_OP(normal_)>::call(self, mean, std, generator);
  // }

  // TORCH_LAZY_FN_COUNTER("lazy::");
  // auto device = bridge::GetBackendDevice(self);
  // LazyTensor lazy_self = GetLtcTensorOrCreateForWrappedNumber(self, *device);
  // std::vector<torch::lazy::Shape> shapes =
  //     {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())};
  // auto node = torch::lazy::MakeNode<Normal>(
  //     lazy_self.GetIrValue(), mean, std, std::move(shapes));
  // lazy_self.SetInPlaceIrValue(node);
  // return self;
}

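// _unsafe_view makes no aliasing guarantees, so the lazy backend can lower it
// as a plain view_copy.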
at::Tensor LazyNativeFunctions::_unsafe_view(
    const at::Tensor& self,
    at::IntArrayRef size) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  return LazyNativeFunctions::view_copy_symint(
      self, c10::fromIntArrayRefSlow(size));
}

// This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a
// FunctionalTensorWrapper object.
at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}
at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}

// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      block_diag)>::call(tensors);
}
at::Tensor LazyNativeFunctions::new_empty_strided_symint(
    const at::Tensor& self,
    c10::SymIntArrayRef size,
    c10::SymIntArrayRef stride,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory) {
  return at::functionalization::
      functionalize_aten_op_symint<ATEN_OP(new_empty_strided)>::call(
          self, size, stride, dtype, layout, device, pin_memory);
}

at::Tensor LazyNativeFunctions::narrow_copy_symint(
    const at::Tensor& self,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt length) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      narrow_copy)>::call(self, dim, start, length);
}
at::Tensor LazyNativeFunctions::pixel_shuffle(
    const at::Tensor& self,
    int64_t upscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_shuffle)>::call(self, upscale_factor);
}
at::Tensor LazyNativeFunctions::pixel_unshuffle(
    const at::Tensor& self,
    int64_t downscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_unshuffle)>::call(self, downscale_factor);
}
at::Tensor LazyNativeFunctions::select_backward_symint(
    const at::Tensor& grad_output,
    c10::SymIntArrayRef input_sizes,
    int64_t dim,
    c10::SymInt index) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      select_backward)>::call(grad_output, input_sizes, dim, index);
}
at::Tensor LazyNativeFunctions::_trilinear(
    const at::Tensor& i1,
    const at::Tensor& i2,
    const at::Tensor& i3,
    at::IntArrayRef expand1,
    at::IntArrayRef expand2,
    at::IntArrayRef expand3,
    at::IntArrayRef sumdim,
    int64_t unroll_dim) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::
      call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
}
at::Tensor LazyNativeFunctions::linalg_pinv(
    const at::Tensor& self,
    const c10::optional<at::Tensor>& atol,
    const c10::optional<at::Tensor>& rtol,
    bool hermitian) {
  return at::functionalization::functionalize_aten_op<ATEN_OP2(
      linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}

// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy any mutations
// back to the inputs.
at::Tensor& LazyNativeFunctions::logsumexp_out(
    const at::Tensor& self,
    at::IntArrayRef dim,
    bool keepdim,
    at::Tensor& out) {
  auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
  auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
  // Directly call the composite kernel from core.
  // Make sure to re-enable functionalization first.
  auto curr_tls = c10::impl::tls_local_dispatch_key_set();
  auto tls_reenable_functionalize = c10::impl::PODLocalDispatchKeySet();
  tls_reenable_functionalize.set_included(curr_tls.included_);
  tls_reenable_functionalize.set_excluded(
      curr_tls.excluded_.remove(c10::DispatchKey::Functionalize));
  c10::impl::ForceDispatchKeyGuard guard_(tls_reenable_functionalize);
  at::native::logsumexp_out(self_wrapped, dim, keepdim, out_wrapped);
  auto out_unwrapped =
      at::functionalization::impl::from_functional_tensor(out_wrapped);
  // Propagate mutations back to the inputs (including resizing).
  out.resize_(out_unwrapped.sizes());
  out.copy_(out_unwrapped);
  return out;
}

at::Tensor LazyNativeFunctions::diag_embed(
    const at::Tensor& self,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      diag_embed)>::call(self, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::diagonal_backward_symint(
    const at::Tensor& grad_output,
    at::SymIntArrayRef input_sizes,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}

at::Tensor LazyNativeFunctions::slice_backward_symint(
    const at::Tensor& grad_output,
    at::SymIntArrayRef input_sizes,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt end,
    c10::SymInt step) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}

// Re-use the composite kernel from core; that way we don't need to provide a
// backwards formula for native_group_norm.
std::tuple<Tensor, Tensor, Tensor> LazyNativeFunctions::native_group_norm(
    const at::Tensor& input,
    const c10::optional<at::Tensor>& weight,
    const c10::optional<at::Tensor>& bias,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps) {
  return at::native::math_group_norm(
      input, weight, bias, N, C, HxW, group, eps);
}

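// Currently a no-op for the TorchScript backend.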
void InitializeAtenBindings() {}

} // namespace lazy
} // namespace torch