1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // The Stream is used in conjunction with the StreamExecutor "parent" to |
17 | // perform actions with a linear stream of dependencies. Dependencies can also |
18 | // be created between Streams to do task management (i.e. limit which tasks |
19 | // can be performed concurrently and specify what task dependencies exist). |
20 | |
21 | #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_ |
22 | #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_ |
23 | |
24 | #include <complex> |
25 | #include <cstdint> |
26 | #include <functional> |
27 | #include <memory> |
28 | #include <type_traits> |
29 | |
30 | #include "absl/base/thread_annotations.h" |
31 | #include "absl/synchronization/mutex.h" |
32 | #include "tensorflow/compiler/xla/stream_executor/blas.h" |
33 | #include "tensorflow/compiler/xla/stream_executor/device_memory.h" |
34 | #include "tensorflow/compiler/xla/stream_executor/dnn.h" |
35 | #include "tensorflow/compiler/xla/stream_executor/event.h" |
36 | #include "tensorflow/compiler/xla/stream_executor/fft.h" |
37 | #include "tensorflow/compiler/xla/stream_executor/kernel.h" |
38 | #include "tensorflow/compiler/xla/stream_executor/launch_dim.h" |
39 | #include "tensorflow/compiler/xla/stream_executor/lib/array_slice.h" |
40 | #include "tensorflow/compiler/xla/stream_executor/platform/port.h" |
41 | #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" |
42 | #include "tensorflow/compiler/xla/stream_executor/temporary_memory_manager.h" |
43 | |
44 | namespace stream_executor { |
45 | |
46 | namespace host { |
47 | class HostBlas; |
48 | class HostFft; |
49 | class HostRng; |
50 | class HostTimer; |
51 | } // namespace host |
52 | |
53 | namespace ocl { |
54 | class CLBlas; |
55 | } // namespace ocl |
56 | |
57 | namespace internal { |
58 | class StreamInterface; |
59 | } // namespace internal |
60 | |
61 | class DeviceMemoryBase; |
62 | template <typename ElemT> |
63 | class DeviceMemory; |
64 | |
65 | class Timer; |
66 | |
67 | namespace dnn { |
68 | class BatchDescriptor; |
69 | class FilterDescriptor; |
70 | class ConvolutionDescriptor; |
71 | class ProfileResult; |
72 | class AlgorithmDesc; |
73 | } // namespace dnn |
74 | |
75 | class StreamExecutor; |
76 | class ScratchAllocator; |
77 | |
78 | namespace detail { |
79 | |
// Helper class to prevent a template function argument from being deduced. This
// is identical to std::type_identity in C++20.
//
// Example: in a signature `f(T, NonDeducedType<T>)`, only the first argument
// participates in deduction of `T`; the second is implicitly converted to
// whatever `T` was deduced as.
template <typename T>
struct NonDeduced {
  using type = T;
};
template <typename T>
using NonDeducedType = typename NonDeduced<T>::type;
88 | |
// Helper to return whether `T` is the same type as any of `Ts`.
//
// Implemented as a single C++17 fold expression instead of a recursive
// overload pair; an empty `Ts` pack yields `false`, matching the former
// recursive base case.
template <typename T, typename... Ts>
constexpr bool is_any_of() {
  return (std::is_same_v<T, Ts> || ...);
}
99 | |
100 | } // namespace detail |
101 | |
102 | // Convert a type to the corresponding QuantizedActivationMode. |
103 | template <typename ElementType> |
104 | struct Quantization; |
105 | |
106 | // Represents a stream of dependent computations on a GPU device. |
107 | // |
108 | // The operations within a stream execute linearly and asynchronously until |
109 | // BlockHostUntilDone() is invoked, which synchronously joins host code with |
110 | // the execution of the stream. |
111 | // |
112 | // If any given operation fails when entraining work for the stream, ok() will |
113 | // indicate that an error has occurred. After initialization, once a stream is |
114 | // !ok(), it will never be ok(). |
115 | // |
116 | // Thread-safe post-initialization. |
117 | class Stream { |
118 | public: |
119 | // Instantiate a stream tied to parent as a platform executor. Work |
120 | // entrained onto this stream will be launched/managed on that |
121 | // StreamExecutor's platform. |
122 | explicit Stream(StreamExecutor *parent); |
123 | |
  // Deallocates any stream resources that the parent StreamExecutor has
  // bestowed upon this object.
127 | ~Stream(); |
128 | |
  // Returns whether any errors have occurred while entraining work for this
  // stream. Per the class comment, after initialization, once this returns
  // false it never returns true again.
  bool ok() const { return !InErrorState(); }
132 | |
133 | // Retrieves execution status back into the stream from the underlying |
134 | // implementation without blocking the stream. |
135 | // |
136 | // Normally, Stream::BlockHostUntilDone is used to get execution status. |
  // However, some devices use out-of-band mechanisms to ensure their streams
138 | // have finished on-device work, without needing to block the streams. (These |
139 | // devices should also override AllowsSyncOnCompletion to return false.) For |
140 | // these devices, this method can be used after work is finished to retrieve |
141 | // execution status. |
142 | port::Status RefreshStatus() TF_LOCKS_EXCLUDED(mu_); |
143 | |
144 | // Initialize the stream. This must be performed before entraining any other |
145 | // operations. |
146 | Stream &Init() TF_LOCKS_EXCLUDED(mu_); |
147 | |
148 | // Initializes timer t via the StreamExecutor. |
149 | Stream &InitTimer(Timer *t); |
150 | |
151 | // Convenience wrapper around Init() and InitTimer(). |
152 | Stream &InitWithTimer(Timer *t); |
153 | |
154 | // Get or create a sub-stream from this stream. If there is any sub-stream in |
155 | // the pool that can be reused then just return this sub-stream. Otherwise |
156 | // create a new sub-stream. |
157 | // |
158 | // TODO(b/112196569): The semantics of failed sub-streams is error-prone. |
159 | Stream *GetOrCreateSubStream() TF_LOCKS_EXCLUDED(mu_); |
160 | |
161 | // Return the sub-stream back to the host stream so that it can be reused |
162 | // later. Sub-streams that are !ok() will not be reused. |
163 | // |
164 | // TODO(b/112196569): The semantics of failed sub-streams is error-prone. |
165 | void ReturnSubStream(Stream *sub_stream) TF_LOCKS_EXCLUDED(mu_); |
166 | |
167 | // Allocate temporary memories. The stream will deallocate them when blocked |
168 | // or destroyed. |
169 | template <typename T> |
170 | port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>> |
171 | AllocateTemporaryArray(uint64_t element_count); |
172 | |
173 | // Entrains onto the stream of operations: a kernel launch with the given |
174 | // (variadic) parameters for the invocation. These arguments can be things |
175 | // like DeviceMemory or primitive types such as int. What arguments you may |
176 | // pass to a given kernel are noted as the template parameters to the |
177 | // TypedKernel type that the machocc compiler generates. |
178 | // |
179 | // Template parameters: |
180 | // Params... The type list of formal parameters that the typed kernel |
181 | // expects, which is matched against Args... |
182 | // Args... The deduced type list for passed actual arguments |
183 | // |
184 | // Implementation: A compile-time compatibility check is performed that has |
185 | // some leniency versus an exact parameter pack match -- for example, |
186 | // `const DeviceMemory<T>` is considered "pack compatible" with a |
187 | // `const DeviceMemory<T>&` formal parameter; in part, because we don't have |
188 | // perfect forwarding support without rvalue references. It also attempts to |
189 | // spit out helpful static_assert error traces with information as to the |
190 | // argument number and types that were mismatched. |
191 | template <typename... Params, typename... Args> |
192 | port::Status ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, |
193 | const TypedKernel<Params...> &kernel, Args... args); |
194 | |
195 | // Record a "start" event for the interval timer at this point in the |
196 | // stream's execution (relative to the previously and subsequently enqueued |
197 | // items in the stream's execution). Streams may be started/stopped multiple |
198 | // times. |
199 | Stream &ThenStartTimer(Timer *t); |
200 | |
201 | // Record a "stop" event for the interval timer at this point in the |
202 | // stream's execution. See also Stream::ThenStartTimer. |
203 | Stream &ThenStopTimer(Timer *t); |
204 | |
  // TODO(leary) If work is added to the stream that is being depended upon,
  // then what? Have to describe what happens.
  //
  // Variadic form: entrains a dependency on every Stream* argument. Each
  // instantiation peels off one stream; recursion terminates at the
  // single-Stream* overload below. Because the remaining pack is processed
  // before `other`, the waits are enqueued in reverse argument order.
  template <typename... Params>
  Stream &ThenWaitFor(Stream *other, Params... more_streams) {
    return ThenWaitFor(more_streams...).ThenWaitFor(other);
  }
211 | |
212 | // Create a dependency for this stream's next work on the other stream |
213 | // completing. Does not take ownership of other, and other must not be |
214 | // null. |
215 | // |
216 | // Checks that a stream does not wait for itself, and it is up to the |
217 | // user to guarantee that a stream does not come to wait on itself in a |
218 | // cyclic manner; in that case, behavior is undefined. |
219 | // |
220 | // N.B. Base recursion case for the variadic ThenWaitFor. |
221 | Stream &ThenWaitFor(Stream *other); |
222 | |
223 | // Waits for all streams values in others. |
224 | // Checks that there is no shallow circular wait (i.e. that "this" is not in |
225 | // others) |
226 | template <typename P> |
227 | Stream &ThenWaitFor(P others) { |
228 | for (auto &stream : *others) { |
229 | CHECK_NE(stream.get(), this); |
230 | ThenWaitFor(stream.get()); |
231 | } |
232 | return *this; |
233 | } |
234 | |
235 | // Waits for an event object to be set. |
236 | // Note that ThenRecordEvent must have been called on the event before |
237 | // you call this function; otherwise the event will be considered complete |
238 | // and this wait will do nothing. |
239 | Stream &ThenWaitFor(Event *event); |
240 | |
241 | // Inserts the specified event into the end of this stream. Once the stream |
242 | // has processed all events prior to the insertion point, the event will be |
243 | // marked as completed. |
244 | // The stream does not take ownership of event - meaning that event's lifetime |
245 | // must extend past the point at which it is marked complete! |
246 | Stream &ThenRecordEvent(Event *event); |
247 | |
248 | //////////////// |
249 | // DNN support |
250 | // |
251 | // See DnnSupport::* for comments on the following methods. |
252 | |
253 | Stream &ThenBatchNormalizationForward( |
254 | const DeviceMemory<float> &x, const DeviceMemory<float> &scale, |
255 | const DeviceMemory<float> &offset, |
256 | const DeviceMemory<float> &estimated_mean, |
257 | const DeviceMemory<float> &estimated_variance, |
258 | const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc, |
259 | const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
260 | const double exponential_average_factor, |
261 | dnn::ActivationMode activation_mode, DeviceMemory<float> *y, |
262 | DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var, |
263 | DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var, |
264 | bool is_training, ScratchAllocator *reserve_space_allocator, |
265 | ScratchAllocator *workspace_allocator); |
266 | |
267 | Stream &ThenBatchNormalizationBackward( |
268 | const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x, |
269 | const DeviceMemory<float> &scale, const DeviceMemory<float> &offset, |
270 | const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var, |
271 | const DeviceMemory<float> &y, const dnn::BatchDescriptor &x_desc, |
272 | const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
273 | dnn::ActivationMode activation_mode, DeviceMemory<float> *x_backprop, |
274 | DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop, |
275 | DeviceMemory<float> *side_input_backprop, |
276 | DeviceMemory<uint8_t> *reserve_space_data, |
277 | ScratchAllocator *workspace_allocator); |
278 | |
279 | Stream &ThenBatchNormalizationForward( |
280 | const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, |
281 | const DeviceMemory<float> &offset, |
282 | const DeviceMemory<float> &estimated_mean, |
283 | const DeviceMemory<float> &estimated_variance, |
284 | const DeviceMemory<Eigen::half> &side_input, |
285 | const dnn::BatchDescriptor &x_desc, |
286 | const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
287 | const double exponential_average_factor, |
288 | dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y, |
289 | DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var, |
290 | DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var, |
291 | bool is_training, ScratchAllocator *reserve_space_allocator, |
292 | ScratchAllocator *workspace_allocator); |
293 | |
294 | Stream &ThenBatchNormalizationBackward( |
295 | const DeviceMemory<Eigen::half> &y_backprop, |
296 | const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, |
297 | const DeviceMemory<float> &offset, const DeviceMemory<float> &mean, |
298 | const DeviceMemory<float> &inv_var, const DeviceMemory<Eigen::half> &y, |
299 | const dnn::BatchDescriptor &x_desc, |
300 | const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, |
301 | dnn::ActivationMode activation_mode, |
302 | DeviceMemory<Eigen::half> *x_backprop, |
303 | DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop, |
304 | DeviceMemory<Eigen::half> *side_input_backprop, |
305 | DeviceMemory<uint8_t> *reserve_space_data, |
306 | ScratchAllocator *workspace_allocator); |
307 | |
308 | Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor, |
309 | const DeviceMemory<float> &input_data, |
310 | const dnn::FilterDescriptor &filter_descriptor, |
311 | const DeviceMemory<float> &filter_data, |
312 | const dnn::ConvolutionDescriptor &convolution_descriptor, |
313 | const dnn::BatchDescriptor &output_descriptor, |
314 | DeviceMemory<float> *output); |
315 | |
316 | Stream &ThenConvolveQuantized( |
317 | const dnn::BatchDescriptor &input_descriptor, |
318 | const DeviceMemory<float> &input_data, |
319 | const dnn::FilterDescriptor &filter_descriptor, |
320 | const DeviceMemory<int8_t> &filter_coefficients, |
321 | const DeviceMemory<float> &coefficient_scales, |
322 | const dnn::ConvolutionDescriptor &convolution_descriptor, |
323 | const dnn::BatchDescriptor &output_descriptor, |
324 | DeviceMemory<float> *output_data); |
325 | |
326 | Stream &ThenConvolveQuantized( |
327 | const dnn::BatchDescriptor &input_descriptor, |
328 | const DeviceMemory<float> &input_data, |
329 | const dnn::FilterDescriptor &filter_descriptor, |
330 | const DeviceMemory<int16> &filter_coefficients, |
331 | const DeviceMemory<float> &coefficient_scales, |
332 | const dnn::ConvolutionDescriptor &convolution_descriptor, |
333 | const dnn::BatchDescriptor &output_descriptor, |
334 | DeviceMemory<float> *output_data); |
335 | |
336 | template <typename InputType, typename OutputType> |
337 | port::Status ConvolveWithAlgorithm( |
338 | dnn::ConvolutionKind kind, const dnn::BatchDescriptor &input_descriptor, |
339 | DeviceMemory<InputType> input_data, |
340 | const dnn::FilterDescriptor &filter_descriptor, |
341 | DeviceMemory<InputType> filter_data, |
342 | const dnn::BatchDescriptor &output_descriptor, |
343 | DeviceMemory<OutputType> output_data, |
344 | const dnn::ConvolutionDescriptor &convolution_descriptor, |
345 | ScratchAllocator *scratch_allocator, |
346 | const dnn::AlgorithmConfig &algorithm_config, |
347 | dnn::ProfileResult *output_profile_result) { |
348 | DeviceMemory<uint8_t> scratch_memory; |
349 | dnn::AlgorithmDesc algorithm_desc; |
350 | if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
351 | TF_RETURN_IF_ERROR(dnn->PrepareForConvolution( |
352 | kind, this, input_descriptor, input_data, filter_descriptor, |
353 | filter_data, output_descriptor, output_data, convolution_descriptor, |
354 | algorithm_config, scratch_allocator, &algorithm_desc, |
355 | &scratch_memory)); |
356 | return dnn->DoConvolve(kind, dnn::ToDataType<InputType>::value, |
357 | dnn::ToDataType<OutputType>::value, this, |
358 | input_descriptor, input_data, filter_descriptor, |
359 | filter_data, output_descriptor, output_data, |
360 | convolution_descriptor, algorithm_desc, |
361 | scratch_memory, output_profile_result); |
362 | } |
363 | return port::UnimplementedError("DNN library is not found." ); |
364 | } |
365 | |
366 | template <typename InputT, typename ScaleT, typename SideInputT, |
367 | typename BiasT, typename OutputT> |
368 | port::Status FusedConvolveWithAlgorithm( |
369 | const dnn::BatchDescriptor &conv_input_descriptor, |
370 | const DeviceMemory<InputT> &conv_input_data, ScaleT conv_input_scale, |
371 | const dnn::FilterDescriptor &filter_descriptor, |
372 | const DeviceMemory<InputT> &filter_data, |
373 | const dnn::ConvolutionDescriptor &convolution_descriptor, |
374 | const DeviceMemory<SideInputT> &side_input_data, ScaleT side_input_scale, |
375 | const dnn::BatchDescriptor &bias_descriptor, |
376 | const DeviceMemory<BiasT> &biases, dnn::ActivationMode activation_mode, |
377 | const dnn::BatchDescriptor &output_descriptor, |
378 | DeviceMemory<OutputT> *output, ScratchAllocator *scratch_allocator, |
379 | const dnn::AlgorithmConfig &algorithm_config, |
380 | dnn::ProfileResult *output_profile_result) { |
381 | if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
382 | return dnn->DoFusedConvolve( |
383 | this, dnn::ToDataType<InputT>::value, |
384 | dnn::ToDataType<SideInputT>::value, dnn::ToDataType<BiasT>::value, |
385 | dnn::ToDataType<OutputT>::value, conv_input_descriptor, |
386 | conv_input_data, conv_input_scale, filter_descriptor, filter_data, |
387 | convolution_descriptor, side_input_data, side_input_scale, |
388 | bias_descriptor, biases, activation_mode, output_descriptor, *output, |
389 | scratch_allocator, algorithm_config, output_profile_result); |
390 | } |
391 | return port::UnimplementedError("DNN library is not found." ); |
392 | } |
393 | |
394 | port::StatusOr<std::unique_ptr<const dnn::ConvRunner>> ConvolveRunnerFromDesc( |
395 | const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, |
396 | dnn::DataType element_type, dnn::DataType output_type, |
397 | const dnn::BatchDescriptor &input_descriptor, |
398 | const dnn::FilterDescriptor &filter_descriptor, |
399 | const dnn::BatchDescriptor &output_descriptor, |
400 | const dnn::ConvolutionDescriptor &convolution_descriptor) { |
401 | dnn::DnnSupport *dnn_support = parent_->AsDnn(); |
402 | if (!dnn_support) { |
403 | return port::UnimplementedError("DNN library is not found." ); |
404 | } |
405 | return dnn_support->ConvolveRunnerFromDesc( |
406 | this, algorithm_desc, kind, element_type, output_type, input_descriptor, |
407 | filter_descriptor, output_descriptor, convolution_descriptor); |
408 | } |
409 | |
410 | port::StatusOr<std::unique_ptr<const dnn::FusedConvRunner>> |
411 | FusedConvolveRunnerFromDesc( |
412 | const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, |
413 | dnn::DataType element_type, dnn::DataType bias_type, |
414 | dnn::DataType output_type, double conv_input_scale, |
415 | double side_input_scale, double leakyrelu_alpha, |
416 | const dnn::BatchDescriptor &input_descriptor, |
417 | const dnn::FilterDescriptor &filter_descriptor, |
418 | const dnn::BatchDescriptor &bias_descriptor, |
419 | const dnn::BatchDescriptor &output_descriptor, |
420 | const dnn::ConvolutionDescriptor &convolution_descriptor, |
421 | dnn::ActivationMode activation_mode) { |
422 | dnn::DnnSupport *dnn_support = parent_->AsDnn(); |
423 | if (!dnn_support) { |
424 | return port::UnimplementedError("DNN library is not found." ); |
425 | } |
426 | return dnn_support->FusedConvolveRunnerFromDesc( |
427 | this, algorithm_desc, kind, element_type, bias_type, output_type, |
428 | conv_input_scale, side_input_scale, leakyrelu_alpha, input_descriptor, |
429 | filter_descriptor, bias_descriptor, output_descriptor, |
430 | convolution_descriptor, activation_mode); |
431 | } |
432 | |
433 | Stream &ThenSeparableConvolve( |
434 | const dnn::BatchDescriptor &input_descriptor, |
435 | const DeviceMemory<float> &input_data, |
436 | const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier, |
437 | const DeviceMemory<float> &first_weights, |
438 | const DeviceMemory<float> &second_weights, |
439 | const dnn::ConvolutionDescriptor &convolution_descriptor, |
440 | const dnn::BatchDescriptor &output_descriptor, |
441 | DeviceMemory<float> *output); |
442 | |
443 | Stream &ThenMatMul(const DeviceMemory<float> &input_data, |
444 | const DeviceMemory<float> &weights, |
445 | const dnn::BatchDescriptor &input_dimensions, |
446 | const dnn::BatchDescriptor &output_dimensions, |
447 | DeviceMemory<float> *output_data); |
448 | |
449 | Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data, |
450 | const DeviceMemory<int8_t> &weights, |
451 | const DeviceMemory<float> &weight_scales, |
452 | const dnn::BatchDescriptor &input_dimensions, |
453 | const dnn::BatchDescriptor &output_dimensions, |
454 | DeviceMemory<float> *output_data); |
455 | |
456 | Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data, |
457 | const DeviceMemory<int16> &weights, |
458 | const DeviceMemory<float> &weight_scales, |
459 | const dnn::BatchDescriptor &input_dimensions, |
460 | const dnn::BatchDescriptor &output_dimensions, |
461 | DeviceMemory<float> *output_data); |
462 | |
463 | Stream &ThenBiasAdd(const DeviceMemory<float> &input_data, |
464 | const DeviceMemory<float> &biases, |
465 | const dnn::BatchDescriptor &dimensions, |
466 | DeviceMemory<float> *output_data); |
467 | |
468 | template <typename ElementType> |
469 | port::Status ThenPoolForward( |
470 | const dnn::PoolingDescriptor &pooling_dimensions, |
471 | const dnn::BatchDescriptor &input_dimensions, |
472 | const DeviceMemory<ElementType> &input_data, |
473 | const dnn::BatchDescriptor &output_dimensions, |
474 | DeviceMemory<ElementType> *output_data, |
475 | ScratchAllocator *workspace_allocator = nullptr) { |
476 | if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
477 | return dnn->DoPoolForward(dnn::ToDataType<ElementType>::value, this, |
478 | pooling_dimensions, input_dimensions, |
479 | input_data, output_dimensions, *output_data, |
480 | workspace_allocator); |
481 | } |
482 | return port::UnimplementedError("DNN library is not found." ); |
483 | } |
484 | |
485 | template <typename ElementType> |
486 | port::Status ThenPoolBackward( |
487 | const dnn::PoolingDescriptor &pooling_dimensions, |
488 | const dnn::BatchDescriptor &input_dimensions, |
489 | const DeviceMemory<ElementType> &input_data, |
490 | const dnn::BatchDescriptor &output_dimensions, |
491 | const DeviceMemory<ElementType> &output_data, |
492 | const DeviceMemory<ElementType> &input_diff_data, |
493 | DeviceMemory<ElementType> *output_diff_data, |
494 | ScratchAllocator *workspace_allocator = nullptr) { |
495 | if (dnn::DnnSupport *dnn = parent_->AsDnn()) { |
496 | return dnn->DoPoolBackward( |
497 | dnn::ToDataType<ElementType>::value, this, pooling_dimensions, |
498 | input_dimensions, input_data, output_dimensions, output_data, |
499 | input_diff_data, *output_diff_data, workspace_allocator); |
500 | } |
501 | return port::UnimplementedError("DNN library is not found." ); |
502 | } |
503 | |
504 | Stream &ThenNormalizeWithDimensions( |
505 | const dnn::NormalizeDescriptor &normalize_descriptor, |
506 | const dnn::BatchDescriptor &dimensions, |
507 | const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data); |
508 | |
509 | Stream &ThenNormalizeBackwardWithDimensions( |
510 | const dnn::NormalizeDescriptor &normalize_descriptor, |
511 | const dnn::BatchDescriptor &dimensions, |
512 | const DeviceMemory<float> &raw_data, |
513 | const DeviceMemory<float> &normalized_data, |
514 | const DeviceMemory<float> &normalized_variable_gradient, |
515 | DeviceMemory<float> *raw_variable_gradient, |
516 | ScratchAllocator *workspace_allocator = nullptr); |
517 | |
518 | Stream &ThenActivate(dnn::ActivationMode activation_mode, |
519 | const dnn::BatchDescriptor &dimensions, |
520 | const DeviceMemory<float> &input_data, |
521 | DeviceMemory<float> *output_data); |
522 | |
523 | // Same as ThenActivate, but also takes an options argument that can be used |
524 | // for platform-specific option flags. |
525 | Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode, |
526 | const dnn::BatchDescriptor &dimensions, |
527 | const DeviceMemory<float> &input_data, |
528 | DeviceMemory<float> *output_data, |
529 | uint64_t options); |
530 | |
531 | Stream &ThenDepthConcatenate( |
532 | port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok |
533 | port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok |
534 | DeviceMemory<float> *output_data); |
535 | |
536 | Stream &ThenSpaceConcatenate( |
537 | port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok |
538 | port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok |
539 | DeviceMemory<float> *output_data, |
540 | dnn::SpaceConcatenateMode concat_direction); |
541 | |
542 | // Change the layout of the data by shrinking one dimension (or set of |
543 | // dimensions) and growing another dimension (or set of dimensions), while |
544 | // keeping the total number of data elements constant, and maintaining the |
545 | // current data ordering. |
546 | Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions, |
547 | const DeviceMemory<float> &input_data, |
548 | const dnn::BatchDescriptor &output_dimensions, |
549 | DeviceMemory<float> *output_data); |
550 | |
551 | // Depth to space takes an X by Y image with depth D*M² and changes it to an |
552 | // MX x MY image with depth D. Each input location (x,y) with depth D*M² in |
553 | // the input image is changed to an MxM contiguous area in the output image, |
554 | // with the values being laid out in raster order specified by |
555 | // DepthToSpaceLayout, and will have a new depth of D. |
556 | // See the DoDepthToSpace comment for more information. |
557 | Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions, |
558 | const DeviceMemory<float> &input_data, |
559 | const dnn::DepthToSpaceLayout &depth_to_space_layout, |
560 | const int sqrt_depth_reduction, |
561 | DeviceMemory<float> *output_data); |
562 | |
563 | // Space to depth is the inverse of depth to space. Space to depth takes each |
564 | // non-overlapping M by M patch (in the X and Y dimensions) with depth D of |
565 | // the input, and transforms it to a 1 by 1 patch with depth D*M². If the |
566 | // input has size (MX, MY, D), the output has size (X, Y, D*M²). The number of |
567 | // data elements is not changed. |
568 | Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions, |
569 | const DeviceMemory<float> &input_data, |
570 | const dnn::DepthToSpaceLayout &space_to_depth_layout, |
571 | const int sqrt_depth_increase, |
572 | DeviceMemory<float> *output_data); |
573 | |
574 | Stream &ThenElementwiseOperate( |
575 | dnn::ElementwiseOperation operation, |
576 | port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok |
577 | port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok |
578 | const dnn::BatchDescriptor &output_dimensions, |
579 | DeviceMemory<float> *output_data); |
580 | |
581 | Stream &ThenElementwiseOperateScaledQuantized( |
582 | dnn::ElementwiseOperation operation, |
583 | port::ArraySlice<int> input_multiplicands, // non-absl ok |
584 | int output_divisor, |
585 | port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok |
586 | port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok |
587 | const dnn::BatchDescriptor &output_dimensions, |
588 | DeviceMemory<float> *output_data); |
589 | |
590 | Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions, |
591 | const DeviceMemory<float> &input_data, int64_t left_pad, |
592 | int64_t right_pad, int64_t top_pad, int64_t bottom_pad, |
593 | DeviceMemory<float> *output_data); |
594 | |
595 | Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions, |
596 | const DeviceMemory<float> &input_data, int64_t left_trim, |
597 | int64_t right_trim, int64_t top_trim, int64_t bottom_trim, |
598 | DeviceMemory<float> *output_data); |
599 | |
600 | // Grows the input tensor by replicating the X and Y dimensions. The batch and |
601 | // depth/feature_map dimensions are unchanged. Currently, the input tensor is |
602 | // limited to X=1 and Y=1. |
603 | Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions, |
604 | const DeviceMemory<float> &input_data, |
605 | int64_t replicate_x, int64_t replicate_y, |
606 | DeviceMemory<float> *output_data); |
607 | |
608 | // See DnnSupport::DoMemcpyD2HQuantized. |
609 | Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src, |
610 | dnn::QuantizedActivationMode mode, |
611 | void *host_dst, uint64_t size); |
612 | |
613 | // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice |
614 | // and uses the Quantization trait to call the generic version of |
615 | // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode. |
616 | template <typename ElementType> |
617 | Stream &ThenMemcpyD2HQuantized( |
618 | const DeviceMemory<float> &gpu_unquantized_src, |
619 | port::MutableArraySlice<ElementType> host_dst) { |
620 | return ThenMemcpyD2HQuantized( |
621 | gpu_unquantized_src, Quantization<ElementType>::kModeId, |
622 | host_dst.data(), host_dst.size() * sizeof(ElementType)); |
623 | } |
624 | |
625 | // See DnnSupport::DoMemcpyH2DQuantized. |
626 | Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64_t size, |
627 | dnn::QuantizedActivationMode mode, |
628 | DeviceMemory<float> *gpu_unquantized_dst); |
629 | |
630 | // Template version of ThenMemcpyH2DQuantized that takes an array slice |
631 | // and uses the Quantization trait to call the generic version of |
632 | // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode. |
633 | template <typename ElementType> |
634 | Stream &ThenMemcpyH2DQuantized( |
635 | port::ArraySlice<ElementType> host_src, // non-absl ok |
636 | DeviceMemory<float> *gpu_unquantized_dst) { |
637 | return ThenMemcpyH2DQuantized( |
638 | host_src.data(), host_src.size() * sizeof(ElementType), |
639 | Quantization<ElementType>::kModeId, gpu_unquantized_dst); |
640 | } |
641 | |
642 | // See DnnSupport::DoCopyHostBuffer2Device. |
643 | Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src, |
644 | DeviceMemory<float> *gpu_unquantized_dst); |
645 | |
646 | // See DnnSupport::DoCopyDevice2HostBuffer. |
647 | Stream &ThenCopyDevice2HostBuffer( |
648 | const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst); |
649 | |
650 | ///////////////// |
651 | // BLAS support |
652 | |
653 | // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is |
654 | // present in DeviceMemory, it must be an execution-time constant (i.e. a |
655 | // value |
656 | // that the stream does not change or populate during the course of |
657 | // execution). The value is effectively captured at stream-enqueue time. |
658 | Stream &ThenBlasAxpy(uint64_t elem_count, float alpha, |
659 | const DeviceMemory<float> &x, int incx, |
660 | DeviceMemory<float> *y, int incy); |
661 | Stream &ThenBlasAxpy(uint64_t elem_count, double alpha, |
662 | const DeviceMemory<double> &x, int incx, |
663 | DeviceMemory<double> *y, int incy); |
664 | Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<float> alpha, |
665 | const DeviceMemory<std::complex<float>> &x, int incx, |
666 | DeviceMemory<std::complex<float>> *y, int incy); |
667 | Stream &ThenBlasAxpy(uint64_t elem_count, std::complex<double> alpha, |
668 | const DeviceMemory<std::complex<double>> &x, int incx, |
669 | DeviceMemory<std::complex<double>> *y, int incy); |
670 | |
671 | // See BlasSupport::DoBlasCopy. |
672 | Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<float> &x, |
673 | int incx, DeviceMemory<float> *y, int incy); |
674 | Stream &ThenBlasCopy(uint64_t elem_count, const DeviceMemory<double> &x, |
675 | int incx, DeviceMemory<double> *y, int incy); |
676 | Stream &ThenBlasCopy(uint64_t elem_count, |
677 | const DeviceMemory<std::complex<float>> &x, int incx, |
678 | DeviceMemory<std::complex<float>> *y, int incy); |
679 | Stream &ThenBlasCopy(uint64_t elem_count, |
680 | const DeviceMemory<std::complex<double>> &x, int incx, |
681 | DeviceMemory<std::complex<double>> *y, int incy); |
682 | |
683 | // See BlasSupport::DoBlasScal. |
684 | Stream &ThenBlasScal(uint64_t elem_count, float alpha, DeviceMemory<float> *x, |
685 | int incx); |
686 | Stream &ThenBlasScal(uint64_t elem_count, double alpha, |
687 | DeviceMemory<double> *x, int incx); |
688 | Stream &ThenBlasScal(uint64_t elem_count, float alpha, |
689 | DeviceMemory<std::complex<float>> *x, int incx); |
690 | Stream &ThenBlasScal(uint64_t elem_count, double alpha, |
691 | DeviceMemory<std::complex<double>> *x, int incx); |
692 | Stream &ThenBlasScal(uint64_t elem_count, std::complex<float> alpha, |
693 | DeviceMemory<std::complex<float>> *x, int incx); |
694 | Stream &ThenBlasScal(uint64_t elem_count, std::complex<double> alpha, |
695 | DeviceMemory<std::complex<double>> *x, int incx); |
696 | |
697 | // See BlasSupport::DoBlasGemv. |
698 | Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, float alpha, |
699 | const DeviceMemory<float> &a, int lda, |
700 | const DeviceMemory<float> &x, int incx, float beta, |
701 | DeviceMemory<float> *y, int incy); |
702 | Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, |
703 | double alpha, const DeviceMemory<double> &a, int lda, |
704 | const DeviceMemory<double> &x, int incx, double beta, |
705 | DeviceMemory<double> *y, int incy); |
706 | Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, |
707 | std::complex<float> alpha, |
708 | const DeviceMemory<std::complex<float>> &a, int lda, |
709 | const DeviceMemory<std::complex<float>> &x, int incx, |
710 | std::complex<float> beta, |
711 | DeviceMemory<std::complex<float>> *y, int incy); |
712 | Stream &ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n, |
713 | std::complex<double> alpha, |
714 | const DeviceMemory<std::complex<double>> &a, int lda, |
715 | const DeviceMemory<std::complex<double>> &x, int incx, |
716 | std::complex<double> beta, |
717 | DeviceMemory<std::complex<double>> *y, int incy); |
718 | |
719 | Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n, |
720 | float alpha, const DeviceMemory<float> &a, |
721 | int lda, const DeviceMemory<float> &x, |
722 | int incx, float beta, |
723 | DeviceMemory<float> *y, int incy, |
724 | blas::ProfileResult *output_profile_result); |
725 | Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64_t m, uint64 n, |
726 | double alpha, const DeviceMemory<double> &a, |
727 | int lda, const DeviceMemory<double> &x, |
728 | int incx, double beta, |
729 | DeviceMemory<double> *y, int incy, |
730 | blas::ProfileResult *output_profile_result); |
731 | Stream &ThenBlasGemvWithProfiling( |
732 | blas::Transpose trans, uint64_t m, uint64 n, std::complex<float> alpha, |
733 | const DeviceMemory<std::complex<float>> &a, int lda, |
734 | const DeviceMemory<std::complex<float>> &x, int incx, |
735 | std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, |
736 | blas::ProfileResult *output_profile_result); |
737 | Stream &ThenBlasGemvWithProfiling( |
738 | blas::Transpose trans, uint64_t m, uint64 n, std::complex<double> alpha, |
739 | const DeviceMemory<std::complex<double>> &a, int lda, |
740 | const DeviceMemory<std::complex<double>> &x, int incx, |
741 | std::complex<double> beta, DeviceMemory<std::complex<double>> *y, |
742 | int incy, blas::ProfileResult *output_profile_result); |
743 | |
744 | // See BlasSupport::DoBlasSbmv. |
745 | Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, float alpha, |
746 | const DeviceMemory<float> &a, int lda, |
747 | const DeviceMemory<float> &x, int incx, float beta, |
748 | DeviceMemory<float> *y, int incy); |
749 | Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k, |
750 | double alpha, const DeviceMemory<double> &a, int lda, |
751 | const DeviceMemory<double> &x, int incx, double beta, |
752 | DeviceMemory<double> *y, int incy); |
753 | |
754 | template <typename InputType> |
755 | port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
756 | uint64_t m, uint64 n, uint64 k, |
757 | const DeviceMemory<InputType> &a, int lda, |
758 | const DeviceMemory<InputType> &b, int ldb, |
759 | DeviceMemory<InputType> *c, int ldc, |
760 | blas::ComputePrecision precision) { |
761 | InputType alpha{1.0}; |
762 | InputType beta{0.0}; |
763 | return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, |
764 | ldc, precision); |
765 | } |
766 | |
767 | // TODO(parkers): Update all callers to pass kDefaultComputePrecision. |
768 | template <typename InputType> |
769 | port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
770 | uint64_t m, uint64 n, uint64 k, |
771 | const DeviceMemory<InputType> &a, int lda, |
772 | const DeviceMemory<InputType> &b, int ldb, |
773 | DeviceMemory<InputType> *c, int ldc) { |
774 | return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc, |
775 | blas::kDefaultComputePrecision); |
776 | } |
777 | |
778 | template <typename InputType, typename ConstantType> |
779 | port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
780 | uint64_t m, uint64 n, uint64 k, ConstantType alpha, |
781 | const DeviceMemory<InputType> &a, int lda, |
782 | const DeviceMemory<InputType> &b, int ldb, |
783 | ConstantType beta, DeviceMemory<InputType> *c, |
784 | int ldc, blas::ComputePrecision precision) { |
785 | static_assert( |
786 | detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, float, |
787 | double, std::complex<float>, std::complex<double>>(), |
788 | "Input can be half, bf16, float, double, std::complex<float> or " |
789 | "std::complex<double>" ); |
790 | static_assert(!std::is_same_v<InputType, Eigen::half> || |
791 | detail::is_any_of<ConstantType, float, Eigen::half>(), |
792 | "If input is Eigen::half, constant has to be either " |
793 | "Eigen::half or float" ); |
794 | static_assert( |
795 | detail::is_any_of<InputType, Eigen::half, ConstantType>(), |
796 | "If input is not Eigen::half, constant and input types have to match" ); |
797 | blas::BlasSupport *blas = parent()->AsBlas(); |
798 | if (!blas) { |
799 | return port::InternalError( |
800 | "Attempting to perform BLAS operation using " |
801 | "StreamExecutor without BLAS support" ); |
802 | } |
803 | |
804 | void *alpha_ptr = α |
805 | void *beta_ptr = β |
806 | float alpha_storage, beta_storage; |
807 | UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
808 | &beta_storage); |
809 | |
810 | return blas->DoBlasGemm(this, transa, transb, m, n, k, |
811 | blas::ToDataType<InputType>::value, alpha_ptr, a, |
812 | lda, b, ldb, beta_ptr, c, ldc, precision); |
813 | } |
814 | |
815 | // TODO(parkers): Update all callers to pass kDefaultComputePrecision. |
816 | template <typename InputType, typename ConstantType> |
817 | port::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, |
818 | uint64_t m, uint64 n, uint64 k, ConstantType alpha, |
819 | const DeviceMemory<InputType> &a, int lda, |
820 | const DeviceMemory<InputType> &b, int ldb, |
821 | ConstantType beta, DeviceMemory<InputType> *c, |
822 | int ldc) { |
823 | return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, |
824 | ldc, blas::kDefaultComputePrecision); |
825 | } |
826 | |
827 | Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, |
828 | blas::Transpose transb, uint64_t m, |
829 | uint64 n, uint64_t k, float alpha, |
830 | const DeviceMemory<Eigen::half> &a, int lda, |
831 | const DeviceMemory<Eigen::half> &b, int ldb, |
832 | float beta, DeviceMemory<Eigen::half> *c, |
833 | int ldc, |
834 | blas::ProfileResult *output_profile_result); |
835 | Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, |
836 | blas::Transpose transb, uint64_t m, |
837 | uint64 n, uint64_t k, float alpha, |
838 | const DeviceMemory<float> &a, int lda, |
839 | const DeviceMemory<float> &b, int ldb, |
840 | float beta, DeviceMemory<float> *c, int ldc, |
841 | blas::ProfileResult *output_profile_result); |
842 | Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, |
843 | blas::Transpose transb, uint64_t m, |
844 | uint64 n, uint64_t k, double alpha, |
845 | const DeviceMemory<double> &a, int lda, |
846 | const DeviceMemory<double> &b, int ldb, |
847 | double beta, DeviceMemory<double> *c, |
848 | int ldc, |
849 | blas::ProfileResult *output_profile_result); |
850 | Stream &ThenBlasGemmWithProfiling( |
851 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
852 | uint64_t k, std::complex<float> alpha, |
853 | const DeviceMemory<std::complex<float>> &a, int lda, |
854 | const DeviceMemory<std::complex<float>> &b, int ldb, |
855 | std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, |
856 | blas::ProfileResult *output_profile_result); |
857 | Stream &ThenBlasGemmWithProfiling( |
858 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
859 | uint64_t k, std::complex<double> alpha, |
860 | const DeviceMemory<std::complex<double>> &a, int lda, |
861 | const DeviceMemory<std::complex<double>> &b, int ldb, |
862 | std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, |
863 | blas::ProfileResult *output_profile_result); |
864 | |
865 | template <typename InputType, typename OutputType> |
866 | port::Status ThenBlasGemmWithAlgorithm( |
867 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
868 | uint64_t k, const DeviceMemory<InputType> &a, int lda, |
869 | const DeviceMemory<InputType> &b, int ldb, DeviceMemory<OutputType> *c, |
870 | int ldc, blas::ComputationType computation_type, |
871 | blas::AlgorithmType algorithm, |
872 | blas::ProfileResult *output_profile_result) { |
873 | OutputType alpha{1}; |
874 | OutputType beta{0}; |
875 | return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b, |
876 | ldb, beta, c, ldc, computation_type, |
877 | algorithm, blas::kDefaultComputePrecision, |
878 | output_profile_result); |
879 | } |
880 | |
881 | template <typename InputType, typename OutputType, typename ConstantType> |
882 | port::Status ThenBlasGemmWithAlgorithm( |
883 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
884 | uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda, |
885 | const DeviceMemory<InputType> &b, int ldb, ConstantType beta, |
886 | DeviceMemory<OutputType> *c, int ldc, |
887 | blas::ComputationType computation_type, blas::AlgorithmType algorithm, |
888 | blas::ComputePrecision precision, |
889 | blas::ProfileResult *output_profile_result) { |
890 | TF_RETURN_IF_ERROR( |
891 | CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>( |
892 | computation_type)); |
893 | |
894 | blas::BlasSupport *blas = parent()->AsBlas(); |
895 | if (!blas) { |
896 | return port::InternalError( |
897 | "Attempting to perform BLAS operation using " |
898 | "StreamExecutor without BLAS support" ); |
899 | } |
900 | |
901 | void *alpha_ptr = α |
902 | void *beta_ptr = β |
903 | float alpha_storage, beta_storage; |
904 | UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
905 | &beta_storage); |
906 | |
907 | port::Status st = blas->DoBlasGemmWithAlgorithm( |
908 | this, transa, transb, m, n, k, alpha_ptr, a, |
909 | blas::ToDataType<InputType>::value, lda, b, |
910 | blas::ToDataType<InputType>::value, ldb, beta_ptr, c, |
911 | blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm, |
912 | precision, output_profile_result); |
913 | if (output_profile_result) { |
914 | // The error is recorded in the profile. |
915 | return ::tsl::OkStatus(); |
916 | } |
917 | return st; |
918 | } |
919 | |
920 | template <typename InputType, typename OutputType, typename ConstantType> |
921 | port::Status ThenBlasGemmStridedBatchedWithAlgorithm( |
922 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
923 | uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda, |
924 | int64_t stride_a, const DeviceMemory<InputType> &b, int ldb, |
925 | int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc, |
926 | int64_t stride_c, int batch_count, blas::ComputationType computation_type, |
927 | blas::AlgorithmType algorithm, blas::ComputePrecision precision, |
928 | blas::ProfileResult *output_profile_result) { |
929 | TF_RETURN_IF_ERROR( |
930 | CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>( |
931 | computation_type)); |
932 | |
933 | blas::BlasSupport *blas = parent()->AsBlas(); |
934 | if (!blas) { |
935 | return port::InternalError( |
936 | "Attempting to perform BLAS operation using " |
937 | "StreamExecutor without BLAS support" ); |
938 | } |
939 | void *alpha_ptr = α |
940 | void *beta_ptr = β |
941 | float alpha_storage, beta_storage; |
942 | UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
943 | &beta_storage); |
944 | port::Status st = blas->DoBlasGemmStridedBatchedWithAlgorithm( |
945 | this, transa, transb, m, n, k, alpha_ptr, a, |
946 | blas::ToDataType<InputType>::value, lda, stride_a, b, |
947 | blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c, |
948 | blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count, |
949 | computation_type, algorithm, precision, output_profile_result); |
950 | if (output_profile_result) { |
951 | // The error is recorded in the profile. |
952 | return ::tsl::OkStatus(); |
953 | } |
954 | return st; |
955 | } |
956 | |
  // Non-owning list of device-memory pointers; used by the batched BLAS
  // entry points below.
  template <typename T>
  using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>;  // non-absl ok
959 | |
960 | // See BlasSupport::DoBlasGemmBatched. |
961 | Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, |
962 | uint64_t m, uint64 n, uint64_t k, float alpha, |
963 | const DeviceMemorySlice<Eigen::half> &a, int lda, |
964 | const DeviceMemorySlice<Eigen::half> &b, int ldb, |
965 | float beta, |
966 | const DeviceMemorySlice<Eigen::half> &c, int ldc, |
967 | int batch_count); |
968 | Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, |
969 | uint64_t m, uint64 n, uint64 k, float alpha, |
970 | const DeviceMemorySlice<float> &a, int lda, |
971 | const DeviceMemorySlice<float> &b, int ldb, |
972 | float beta, const DeviceMemorySlice<float> &c, |
973 | int ldc, int batch_count); |
974 | Stream &ThenBlasGemmBatched( |
975 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
976 | uint64 k, double alpha, |
977 | const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok |
978 | int lda, |
979 | const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok |
980 | int ldb, double beta, |
981 | const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok |
982 | int ldc, int batch_count); |
983 | Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, |
984 | uint64_t m, uint64 n, uint64_t k, |
985 | std::complex<float> alpha, |
986 | const DeviceMemorySlice<std::complex<float>> &a, |
987 | int lda, |
988 | const DeviceMemorySlice<std::complex<float>> &b, |
989 | int ldb, std::complex<float> beta, |
990 | const DeviceMemorySlice<std::complex<float>> &c, |
991 | int ldc, int batch_count); |
992 | Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, |
993 | uint64_t m, uint64 n, uint64_t k, |
994 | std::complex<double> alpha, |
995 | const DeviceMemorySlice<std::complex<double>> &a, |
996 | int lda, |
997 | const DeviceMemorySlice<std::complex<double>> &b, |
998 | int ldb, std::complex<double> beta, |
999 | const DeviceMemorySlice<std::complex<double>> &c, |
1000 | int ldc, int batch_count); |
1001 | Stream &ThenBlasGemmBatchedWithScratch( |
1002 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
1003 | uint64_t k, float alpha, const DeviceMemorySlice<Eigen::half> &a, int lda, |
1004 | const DeviceMemorySlice<Eigen::half> &b, int ldb, float beta, |
1005 | const DeviceMemorySlice<Eigen::half> &c, int ldc, int batch_count, |
1006 | ScratchAllocator *scratch_allocator); |
1007 | Stream &ThenBlasGemmBatchedWithScratch( |
1008 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
1009 | uint64_t k, float alpha, const DeviceMemorySlice<float> &a, int lda, |
1010 | const DeviceMemorySlice<float> &b, int ldb, float beta, |
1011 | const DeviceMemorySlice<float> &c, int ldc, int batch_count, |
1012 | ScratchAllocator *scratch_allocator); |
1013 | Stream &ThenBlasGemmBatchedWithScratch( |
1014 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
1015 | uint64_t k, double alpha, const DeviceMemorySlice<double> &a, int lda, |
1016 | const DeviceMemorySlice<double> &b, int ldb, double beta, |
1017 | const DeviceMemorySlice<double> &c, int ldc, int batch_count, |
1018 | ScratchAllocator *scratch_allocator); |
1019 | Stream &ThenBlasGemmBatchedWithScratch( |
1020 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
1021 | uint64_t k, std::complex<float> alpha, |
1022 | const DeviceMemorySlice<std::complex<float>> &a, int lda, |
1023 | const DeviceMemorySlice<std::complex<float>> &b, int ldb, |
1024 | std::complex<float> beta, const DeviceMemorySlice<std::complex<float>> &c, |
1025 | int ldc, int batch_count, ScratchAllocator *scratch_allocator); |
1026 | Stream &ThenBlasGemmBatchedWithScratch( |
1027 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
1028 | uint64_t k, std::complex<double> alpha, |
1029 | const DeviceMemorySlice<std::complex<double>> &a, int lda, |
1030 | const DeviceMemorySlice<std::complex<double>> &b, int ldb, |
1031 | std::complex<double> beta, |
1032 | const DeviceMemorySlice<std::complex<double>> &c, int ldc, |
1033 | int batch_count, ScratchAllocator *scratch_allocator); |
1034 | |
1035 | template <typename InputType, typename ConstantType> |
1036 | port::Status ThenBlasGemmStridedBatched( |
1037 | blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, |
1038 | uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda, |
1039 | int64_t stride_a, const DeviceMemory<InputType> &b, int ldb, |
1040 | int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc, |
1041 | int64_t stride_c, int batch_count, blas::ComputePrecision precision) { |
1042 | static_assert( |
1043 | detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16, |
1044 | double, std::complex<float>, std::complex<double>>(), |
1045 | "Unsupported input type" ); |
1046 | static_assert( |
1047 | std::is_same_v<ConstantType, InputType> || |
1048 | (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() && |
1049 | std::is_same_v<ConstantType, float>), |
1050 | "Mismatched input and alpha/beta types" ); |
1051 | blas::BlasSupport *blas = parent()->AsBlas(); |
1052 | if (!blas) { |
1053 | return port::InternalError( |
1054 | "Attempting to perform BLAS operation using " |
1055 | "StreamExecutor without BLAS support" ); |
1056 | } |
1057 | |
1058 | void *alpha_ptr = α |
1059 | void *beta_ptr = β |
1060 | float alpha_storage, beta_storage; |
1061 | UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage, |
1062 | &beta_storage); |
1063 | |
1064 | return blas->DoBlasGemmStridedBatched( |
1065 | this, transa, transb, m, n, k, blas::ToDataType<InputType>::value, |
1066 | alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, |
1067 | stride_c, batch_count, precision); |
1068 | } |
1069 | |
1070 | // See BlasSupport::DoBlasTrsm. |
1071 | Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
1072 | blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
1073 | uint64_t n, float alpha, const DeviceMemory<float> &a, |
1074 | int lda, DeviceMemory<float> *b, int ldb); |
1075 | Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
1076 | blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
1077 | uint64_t n, double alpha, const DeviceMemory<double> &a, |
1078 | int lda, DeviceMemory<double> *b, int ldb); |
1079 | Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
1080 | blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
1081 | uint64_t n, std::complex<float> alpha, |
1082 | const DeviceMemory<std::complex<float>> &a, int lda, |
1083 | DeviceMemory<std::complex<float>> *b, int ldb); |
1084 | Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, |
1085 | blas::Transpose transa, blas::Diagonal diag, uint64_t m, |
1086 | uint64_t n, std::complex<double> alpha, |
1087 | const DeviceMemory<std::complex<double>> &a, int lda, |
1088 | DeviceMemory<std::complex<double>> *b, int ldb); |
1089 | |
1090 | // See BlasSupport::DoBlasTrsmBatched. |
1091 | Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
1092 | blas::Transpose transa, blas::Diagonal diag, |
1093 | uint64_t m, uint64 n, float alpha, |
1094 | const DeviceMemory<float *> &as, int lda, |
1095 | DeviceMemory<float *> *bs, int ldb, |
1096 | int batch_count); |
1097 | Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
1098 | blas::Transpose transa, blas::Diagonal diag, |
1099 | uint64_t m, uint64 n, double alpha, |
1100 | const DeviceMemory<double *> &as, int lda, |
1101 | DeviceMemory<double *> *bs, int ldb, |
1102 | int batch_count); |
1103 | Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
1104 | blas::Transpose transa, blas::Diagonal diag, |
1105 | uint64_t m, uint64 n, std::complex<float> alpha, |
1106 | const DeviceMemory<std::complex<float> *> &as, |
1107 | int lda, DeviceMemory<std::complex<float> *> *bs, |
1108 | int ldb, int batch_count); |
1109 | Stream &ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo, |
1110 | blas::Transpose transa, blas::Diagonal diag, |
1111 | uint64_t m, uint64 n, std::complex<double> alpha, |
1112 | const DeviceMemory<std::complex<double> *> &as, |
1113 | int lda, DeviceMemory<std::complex<double> *> *bs, |
1114 | int ldb, int batch_count); |
1115 | |
1116 | // See FftSupport::DoFft. |
1117 | Stream &ThenFft(fft::Plan *plan, |
1118 | const DeviceMemory<std::complex<float>> &input, |
1119 | DeviceMemory<std::complex<float>> *output); |
1120 | Stream &ThenFft(fft::Plan *plan, |
1121 | const DeviceMemory<std::complex<double>> &input, |
1122 | DeviceMemory<std::complex<double>> *output); |
1123 | Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, |
1124 | DeviceMemory<std::complex<float>> *output); |
1125 | Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, |
1126 | DeviceMemory<std::complex<double>> *output); |
1127 | Stream &ThenFft(fft::Plan *plan, |
1128 | const DeviceMemory<std::complex<float>> &input, |
1129 | DeviceMemory<float> *output); |
1130 | Stream &ThenFft(fft::Plan *plan, |
1131 | const DeviceMemory<std::complex<double>> &input, |
1132 | DeviceMemory<double> *output); |
1133 | |
1134 | // Makes the RNG use the provided value as the basis for further generation. |
1135 | // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good |
1136 | // sources of seed data if the default (high quality) sources are not |
1137 | // desired. |
1138 | // For most use cases, this function will not be necessary; each provided |
1139 | // back-end implementation will be appropriately seeded by default. |
1140 | // At a minimum 16 bytes of data are required in the seed buffer. |
1141 | // |
1142 | // To seed with good (non-reproducible) data: |
1143 | // File* f = File::Open("/dev/random", "r"); |
1144 | // int64_t bytes_read = f->Read(seed_data, bytes_to_read); |
1145 | // < error checking > |
1146 | // stream.ThenSetRngSeed(seed_data, bytes_read); |
1147 | // |
1148 | // To seed with reproducible data: |
1149 | // uint64_t seed_data[2] = { <data> }; |
1150 | // stream.ThenSetRngSeed(seed_data, 16); |
1151 | Stream &ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes); |
1152 | |
1153 | // Populates the memory indicated by values with uniform-random-distribution |
1154 | // values. TODO(leary) seeding API/description |
1155 | // |
1156 | // Uses the type and size of the DeviceMemory to infer what data should be |
1157 | // populated. |
1158 | Stream &ThenPopulateRandUniform(DeviceMemory<float> *values); |
1159 | Stream &ThenPopulateRandUniform(DeviceMemory<double> *values); |
1160 | Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values); |
1161 | Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values); |
1162 | Stream &ThenPopulateRandGaussian(float mean, float stddev, |
1163 | DeviceMemory<float> *values); |
1164 | Stream &ThenPopulateRandGaussian(double mean, double stddev, |
1165 | DeviceMemory<double> *values); |
1166 | |
1167 | // Entrain onto the stream: a memcpy to a host destination from a GPU source |
1168 | // of the given target size. host_dst must be a pointer to host memory |
1169 | // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and |
1170 | // then registered with StreamExecutor::HostMemoryRegister. |
1171 | Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, |
1172 | uint64_t size); |
1173 | |
1174 | // Entrain onto the stream: a memcpy to a GPU destination from a host source |
1175 | // of the given target size. host_src must be a pointer to host memory |
1176 | // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and |
1177 | // then registered with StreamExecutor::HostMemoryRegister. |
1178 | Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, |
1179 | uint64_t size); |
1180 | |
1181 | // Alternative interface for memcpying from device to host that takes an |
1182 | // array slice. Checks that the destination size can accommodate the host |
1183 | // slice size. |
1184 | template <typename T> |
1185 | Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src, |
1186 | port::MutableArraySlice<T> host_dst) { |
1187 | auto host_size = host_dst.size() * sizeof(T); |
1188 | CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size()); |
1189 | return ThenMemcpy(host_dst.begin(), gpu_src, host_size); |
1190 | } |
1191 | |
1192 | // Alternative interface for memcpying from host to device that takes an |
1193 | // array slice. Checks that the destination size can accommodate the host |
1194 | // slice size. |
1195 | template <typename T> |
1196 | Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src, // non-absl ok |
1197 | DeviceMemory<T> *gpu_dst) { |
1198 | auto host_size = host_src.size() * sizeof(T); |
1199 | CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size); |
1200 | return ThenMemcpy(gpu_dst, host_src.begin(), host_size); |
1201 | } |
1202 | |
1203 | // Entrain onto the stream: a memcpy to a GPU destination from a GPU source |
1204 | // of the given target size. gpu_src/dst must be pointers to GPU memory and |
1205 | // peer access must be enabled between their owning StreamExecutors. |
1206 | Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, |
1207 | uint64_t size); |
1208 | |
  // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
  // ensuring that the host pointer isn't getting confused accidentally with a
  // device pointer if you're not doing metaprogramming against the API.
  Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
                        const DeviceMemoryBase &gpu_src, uint64_t size) {
    // Thin type-safe wrapper: both arguments are device memory, so this can
    // only resolve to the D2D ThenMemcpy overload.
    return ThenMemcpy(gpu_dst, gpu_src, size);
  }
1216 | |
1217 | // Entrain onto the stream: a memset of zero at a GPU location of size bytes. |
1218 | // The location must not be null. |
1219 | Stream &ThenMemZero(DeviceMemoryBase *location, uint64_t size); |
1220 | |
1221 | // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of |
1222 | // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible |
1223 | // by 4). The location must not be null. |
1224 | Stream &ThenMemset32(DeviceMemoryBase *location, uint32_t pattern, |
1225 | uint64_t size); |
1226 | |
1227 | // Enqueue a forward operation of the RNN model onto the stream. |
1228 | // See DnnSupport::DoRnnForward for more details. |
1229 | Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, |
1230 | const dnn::RnnSequenceTensorDescriptor &input_desc, |
1231 | const DeviceMemory<Eigen::half> &input_data, |
1232 | const DeviceMemory<int> &seq_lengths_data, |
1233 | const dnn::RnnStateTensorDescriptor &input_h_desc, |
1234 | const DeviceMemory<Eigen::half> &input_h_data, |
1235 | const dnn::RnnStateTensorDescriptor &input_c_desc, |
1236 | const DeviceMemory<Eigen::half> &input_c_data, |
1237 | const DeviceMemory<Eigen::half> ¶ms, |
1238 | const dnn::RnnSequenceTensorDescriptor &output_desc, |
1239 | DeviceMemory<Eigen::half> *output_data, |
1240 | const dnn::RnnStateTensorDescriptor &output_h_desc, |
1241 | DeviceMemory<Eigen::half> *output_h_data, |
1242 | const dnn::RnnStateTensorDescriptor &output_c_desc, |
1243 | DeviceMemory<Eigen::half> *output_c_data, |
1244 | bool is_training, |
1245 | ScratchAllocator *reserve_space_allocator, |
1246 | ScratchAllocator *workspace_allocator, |
1247 | dnn::ProfileResult *output_profile_result); |
1248 | |
1249 | Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, |
1250 | const dnn::RnnSequenceTensorDescriptor &input_desc, |
1251 | const DeviceMemory<float> &input_data, |
1252 | const DeviceMemory<int> &seq_lengths_data, |
1253 | const dnn::RnnStateTensorDescriptor &input_h_desc, |
1254 | const DeviceMemory<float> &input_h_data, |
1255 | const dnn::RnnStateTensorDescriptor &input_c_desc, |
1256 | const DeviceMemory<float> &input_c_data, |
1257 | const DeviceMemory<float> ¶ms, |
1258 | const dnn::RnnSequenceTensorDescriptor &output_desc, |
1259 | DeviceMemory<float> *output_data, |
1260 | const dnn::RnnStateTensorDescriptor &output_h_desc, |
1261 | DeviceMemory<float> *output_h_data, |
1262 | const dnn::RnnStateTensorDescriptor &output_c_desc, |
1263 | DeviceMemory<float> *output_c_data, bool is_training, |
1264 | ScratchAllocator *reserve_space_allocator, |
1265 | ScratchAllocator *workspace_allocator, |
1266 | dnn::ProfileResult *output_profile_result); |
1267 | |
1268 | Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, |
1269 | const dnn::RnnSequenceTensorDescriptor &input_desc, |
1270 | const DeviceMemory<double> &input_data, |
1271 | const DeviceMemory<int> &seq_lengths_data, |
1272 | const dnn::RnnStateTensorDescriptor &input_h_desc, |
1273 | const DeviceMemory<double> &input_h_data, |
1274 | const dnn::RnnStateTensorDescriptor &input_c_desc, |
1275 | const DeviceMemory<double> &input_c_data, |
1276 | const DeviceMemory<double> ¶ms, |
1277 | const dnn::RnnSequenceTensorDescriptor &output_desc, |
1278 | DeviceMemory<double> *output_data, |
1279 | const dnn::RnnStateTensorDescriptor &output_h_desc, |
1280 | DeviceMemory<double> *output_h_data, |
1281 | const dnn::RnnStateTensorDescriptor &output_c_desc, |
1282 | DeviceMemory<double> *output_c_data, bool is_training, |
1283 | ScratchAllocator *reserve_space_allocator, |
1284 | ScratchAllocator *workspace_allocator, |
1285 | dnn::ProfileResult *output_profile_result); |
1286 | |
1287 | // Enqueue a backward operation of the RNN model onto the stream. |
1288 | // See DnnSupport::DoRnnBackward for more details. |
1289 | Stream &ThenRnnBackward( |
1290 | const dnn::RnnDescriptor &rnn_desc, |
1291 | const dnn::RnnSequenceTensorDescriptor &input_desc, |
1292 | const DeviceMemory<Eigen::half> &input_data, |
1293 | const DeviceMemory<int> &seq_lengths_data, |
1294 | const dnn::RnnStateTensorDescriptor &input_h_desc, |
1295 | const DeviceMemory<Eigen::half> &input_h_data, |
1296 | const dnn::RnnStateTensorDescriptor &input_c_desc, |
1297 | const DeviceMemory<Eigen::half> &input_c_data, |
1298 | const DeviceMemory<Eigen::half> ¶ms, |
1299 | const dnn::RnnSequenceTensorDescriptor &output_desc, |
1300 | const DeviceMemory<Eigen::half> &output_data, |
1301 | const dnn::RnnStateTensorDescriptor &output_h_desc, |
1302 | const DeviceMemory<Eigen::half> &output_h_data, |
1303 | const dnn::RnnStateTensorDescriptor &output_c_desc, |
1304 | const DeviceMemory<Eigen::half> &output_c_data, |
1305 | const DeviceMemory<Eigen::half> &output_backprop_data, |
1306 | const DeviceMemory<Eigen::half> &output_h_backprop_data, |
1307 | const DeviceMemory<Eigen::half> &output_c_backprop_data, |
1308 | DeviceMemory<Eigen::half> *input_backprop_data, |
1309 | DeviceMemory<Eigen::half> *input_h_backprop_data, |
1310 | DeviceMemory<Eigen::half> *input_c_backprop_data, |
1311 | DeviceMemory<Eigen::half> *params_backprop_data, |
1312 | DeviceMemory<uint8_t> *reserve_space_data, |
1313 | ScratchAllocator *workspace_allocator, |
1314 | dnn::ProfileResult *output_profile_result); |
1315 | |
1316 | Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, |
1317 | const dnn::RnnSequenceTensorDescriptor &input_desc, |
1318 | const DeviceMemory<float> &input_data, |
1319 | const DeviceMemory<int> &seq_lengths_data, |
1320 | const dnn::RnnStateTensorDescriptor &input_h_desc, |
1321 | const DeviceMemory<float> &input_h_data, |
1322 | const dnn::RnnStateTensorDescriptor &input_c_desc, |
1323 | const DeviceMemory<float> &input_c_data, |
1324 | const DeviceMemory<float> ¶ms, |
1325 | const dnn::RnnSequenceTensorDescriptor &output_desc, |
1326 | const DeviceMemory<float> &output_data, |
1327 | const dnn::RnnStateTensorDescriptor &output_h_desc, |
1328 | const DeviceMemory<float> &output_h_data, |
1329 | const dnn::RnnStateTensorDescriptor &output_c_desc, |
1330 | const DeviceMemory<float> &output_c_data, |
1331 | const DeviceMemory<float> &output_backprop_data, |
1332 | const DeviceMemory<float> &output_h_backprop_data, |
1333 | const DeviceMemory<float> &output_c_backprop_data, |
1334 | DeviceMemory<float> *input_backprop_data, |
1335 | DeviceMemory<float> *input_h_backprop_data, |
1336 | DeviceMemory<float> *input_c_backprop_data, |
1337 | DeviceMemory<float> *params_backprop_data, |
1338 | DeviceMemory<uint8_t> *reserve_space_data, |
1339 | ScratchAllocator *workspace_allocator, |
1340 | dnn::ProfileResult *output_profile_result); |
1341 | |
1342 | Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, |
1343 | const dnn::RnnSequenceTensorDescriptor &input_desc, |
1344 | const DeviceMemory<double> &input_data, |
1345 | const DeviceMemory<int> &seq_lengths_data, |
1346 | const dnn::RnnStateTensorDescriptor &input_h_desc, |
1347 | const DeviceMemory<double> &input_h_data, |
1348 | const dnn::RnnStateTensorDescriptor &input_c_desc, |
1349 | const DeviceMemory<double> &input_c_data, |
1350 | const DeviceMemory<double> ¶ms, |
1351 | const dnn::RnnSequenceTensorDescriptor &output_desc, |
1352 | const DeviceMemory<double> &output_data, |
1353 | const dnn::RnnStateTensorDescriptor &output_h_desc, |
1354 | const DeviceMemory<double> &output_h_data, |
1355 | const dnn::RnnStateTensorDescriptor &output_c_desc, |
1356 | const DeviceMemory<double> &output_c_data, |
1357 | const DeviceMemory<double> &output_backprop_data, |
1358 | const DeviceMemory<double> &output_h_backprop_data, |
1359 | const DeviceMemory<double> &output_c_backprop_data, |
1360 | DeviceMemory<double> *input_backprop_data, |
1361 | DeviceMemory<double> *input_h_backprop_data, |
1362 | DeviceMemory<double> *input_c_backprop_data, |
1363 | DeviceMemory<double> *params_backprop_data, |
1364 | DeviceMemory<uint8_t> *reserve_space_data, |
1365 | ScratchAllocator *workspace_allocator, |
1366 | dnn::ProfileResult *output_profile_result); |
1367 | |
  // Enqueue a CTCLoss operation onto the stream.
  // See DnnSupport::DoCtcLoss for more details.
  // `costs_data` receives the per-example loss; `grads_data` receives the
  // gradient w.r.t. the (log-)probabilities described by `grads_desc`.
  // NOTE(review): labels/lengths are passed as absl::Span rather than device
  // memory — presumably host-resident; confirm against DnnSupport::DoCtcLoss.
  Stream &ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                      const DeviceMemory<float> &probs_data,
                      absl::Span<const int> labels_data,
                      absl::Span<const int> labels_lengths_data,
                      absl::Span<const int> input_lengths_data,
                      DeviceMemory<float> *costs_data,
                      const dnn::RnnStateTensorDescriptor &grads_desc,
                      DeviceMemory<float> *grads_data,
                      ScratchAllocator *workspace_allocator);
1379 | |
  // Enqueue onto the stream a operation that transforms a tensor.
  // See DnnSupport::DoTransformTensor for more details.
  // `scale` multiplies every element as part of the transformation.
  Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
                              dnn::DataType input_type,
                              const DeviceMemoryBase &input_data,
                              const dnn::BatchDescriptor &output_desc,
                              dnn::DataType output_type, float scale,
                              DeviceMemoryBase *output_data);
1388 | |
1389 | // The templated version of the above ThenTransformTensor. Useful when the |
1390 | // input and output types are statically known. |
1391 | template <typename InElemT, typename OutElemT> |
1392 | Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc, |
1393 | const DeviceMemory<InElemT> &input_data, |
1394 | const dnn::BatchDescriptor &output_desc, |
1395 | DeviceMemory<OutElemT> *output_data) { |
1396 | return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(), |
1397 | input_data, output_desc, |
1398 | dnn::ToDataType<OutElemT>(), output_data); |
1399 | } |
1400 | |
  // (Synchronously) block the host code waiting for the operations
  // entrained on the stream (enqueued to this point in program
  // execution) to complete.
  //
  // Returns an OK status if the blocking was successful and the stream is ok().
  // Otherwise returns an error describing why the blocking failed.
  port::Status BlockHostUntilDone() TF_LOCKS_EXCLUDED(mu_);

  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
  //
  // Entrains onto the stream a function to be executed on the host at some
  // point in the future.
  // Async host callbacks DO NOT block the stream as device functions (or as
  // synchronous host callbacks). No synchronization is possible with
  // asynchronous callbacks; they are strictly fire-and-forget.
  // This method is private due to the potential for undefined behavior with
  // synchronization using OpenCL user events.
  // The ONLY lifetime guarantee in these calls is that the StreamExecutor
  // parameter will still be valid - this Stream may not be!
  // Any callbacks requiring device API calls must use this method.
  Stream &ThenEnqueueOnBackgroundThread(
      std::function<void(StreamExecutor *)> task);

  // Returns the (opaque) platform-specific backing object. Ownership is not
  // transferred to the caller.
  internal::StreamInterface *implementation() { return implementation_.get(); }

  // Entrains onto the stream a callback to the host (from the device).
  // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
  // never fail or its failure is inconsequential.
  //
  // This is kept for backward compatibility. Future code should use
  // ThenDoHostCallbackWithStatus and explicitly return a success status.
  // TODO(b/112125301): Eventually remove this method.
  Stream &ThenDoHostCallback(std::function<void()> callback);

  // Entrains onto the stream a callback to the host (from the device).
  // Host callbacks block/occupy the stream just as device functions
  // (execute one at a time, block later stream operations).
  // Whether the callback return status affects the result of BlockHostUntilDone
  // is platform-dependent.
  //
  // Behavior is undefined when synchronizing using OpenCL user events.
  // Behavior is undefined if host callbacks call device routines or insert
  // them into any stream.
  //
  // On certain platforms, ThenDoHostCallback is expected to have significant
  // negative effects on performance.
  Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);

  // Runs the given callback after the next call to BlockHostUntilDone on this
  // stream (or after the Stream does BlockHostUntilDone in its destructor).
  // This can act as a faster alternative to ThenDoHostCallbackWithStatus for
  // some use cases.
  Stream &ThenRunAfterNextBlockHostUntilDone(std::function<void()> callback);
1460 | |
  // Returns the StreamExecutor (parent object) associated with this stream.
  // CHECK-fails if the stream has no parent.
  StreamExecutor *parent() const {
    CHECK(parent_ != nullptr);
    return parent_;
  }

  // Returns the CUDA compute capability of the device this stream's parent
  // executor targets.
  CudaComputeCapability GetCudaComputeCapability() const {
    return parent()->GetDeviceDescription().cuda_compute_capability();
  }

  // Returns the ROCm compute capability of the device this stream's parent
  // executor targets.
  RocmComputeCapability GetRocmComputeCapability() const {
    return parent()->GetDeviceDescription().rocm_compute_capability();
  }
  // Returns the (internal usage) temporary-memory-allocation manager associated
  // with this stream.
  internal::TemporaryMemoryManager *temporary_memory_manager();

  // Returns a debugging string "[stream=0x...,impl=0x...]".
  std::string DebugStreamPointers() const;
1481 | |
 private:
  // Host-backend BLAS/FFT/RNG implementations need direct access to parent_.
  friend class host::HostBlas;  // for parent_.
  friend class host::HostFft;   // for parent_.
  friend class host::HostRng;   // for parent_.
  template <typename... Args>
  friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
  friend class ocl::CLBlas;    // for parent_.
1490 | // Checks whether types match before a call to extended BLAS version. |
1491 | template <typename ABType, typename CType, typename ScaleType> |
1492 | port::Status CheckTypesForExtendedBlas( |
1493 | blas::ComputationType computation_type) { |
1494 | static_assert( |
1495 | detail::is_any_of<ABType, Eigen::half, Eigen::bfloat16, float, double, |
1496 | int8_t, std::complex<float>, std::complex<double>>(), |
1497 | "The only buffer types supported are: Eigen::half, float, " |
1498 | "double, int8, std::complex<float> and std::complex<double>" ); |
1499 | static_assert( |
1500 | std::is_same_v<ABType, CType> || |
1501 | (std::is_same_v<ABType, int8_t> && std::is_same_v<CType, int32_t>), |
1502 | "Input and output buffer types should be the same unless input is " |
1503 | "int8 and output is int32" ); |
1504 | static_assert( |
1505 | std::is_same_v<ScaleType, CType> || |
1506 | (std::is_same_v<ScaleType, float> && |
1507 | detail::is_any_of<CType, Eigen::half, Eigen::bfloat16>()), |
1508 | "Mismatched alpha/beta and output types" ); |
1509 | |
1510 | bool valid_computation_type = [computation_type] { |
1511 | switch (computation_type) { |
1512 | case blas::ComputationType::kF16: |
1513 | return std::is_same_v<CType, Eigen::half>; |
1514 | case blas::ComputationType::kF32: |
1515 | return detail::is_any_of<CType, Eigen::half, Eigen::bfloat16, float, |
1516 | std::complex<float>>(); |
1517 | case blas::ComputationType::kF64: |
1518 | return detail::is_any_of<CType, double, std::complex<double>>(); |
1519 | case blas::ComputationType::kI32: |
1520 | return std::is_same_v<CType, int32_t>; |
1521 | case blas::ComputationType::kF16AsF32: // fall-through |
1522 | case blas::ComputationType::kBF16AsF32: // fall-through |
1523 | case blas::ComputationType::kTF32AsF32: |
1524 | return detail::is_any_of<CType, float, std::complex<float>>(); |
1525 | } |
1526 | }(); |
1527 | |
1528 | if (!valid_computation_type) { |
1529 | return port::InternalError(absl::StrCat( |
1530 | "Invalid computation type " , |
1531 | blas::ComputationTypeString(computation_type), " for output type: " , |
1532 | blas::DataTypeString(blas::ToDataType<CType>::value))); |
1533 | } |
1534 | return ::tsl::OkStatus(); |
1535 | } |
1536 | |
  // Returns true if a previous operation has put this stream into an error
  // state (i.e. status_ is not OK). Takes a reader lock on mu_.
  bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) {
    absl::ReaderMutexLock lock(&mu_);
    return !status_.ok();
  }

  // Sets the error state if operation_retcode is false.
  // This is a useful shorthand for many stream routines.
  void CheckError(bool operation_retcode) TF_LOCKS_EXCLUDED(mu_);

  // Checks the status and logs the error message, if any.
  void CheckStatus(port::Status status) TF_LOCKS_EXCLUDED(mu_);

  // Unconditionally puts the stream into an error state.
  void SetError() { CheckError(false /* = operation_retcode */); }

  // Puts the stream into an error state and warns that a DNN operation was
  // attempted on a StreamExecutor built without DNN support.
  void SetErrorAndLogNoDnnSupport() {
    SetError();
    LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
                    "without DNN support";
  }
1556 | |
  // Runs the set of callbacks that are intended to run after
  // BlockHostUntilDone.
  void RunAfterBlockHostUntilDoneCallbacks();

  // The StreamExecutor that supports the operation of this stream.
  StreamExecutor *parent_;

  // The platform-dependent implementation that the StreamExecutor interface
  // delegates to.
  std::unique_ptr<internal::StreamInterface> implementation_;

  // mutex that guards the allocation / error state flags.
  // Mutable so that it can be obtained via const reader lock.
  mutable absl::Mutex mu_;

  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from 0 to 1 with a sanity check.
  // See StreamExecutor::AllocateStream.
  bool allocated_ ABSL_GUARDED_BY(mu_);

  // The last error (if any) of all method calls.
  port::Status status_ ABSL_GUARDED_BY(mu_);

  // Sub-streams that are generated from this stream. Each element has a pointer
  // to sub-stream and a boolean value indicating if this substream is ready to
  // be reused.
  std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
      ABSL_GUARDED_BY(mu_);

  // Streams can allocate temporary memories to help with work they enqueue
  // (e.g. for scratch memory spaces). This member tracks those allocations and
  // notes when they can be reclaimed -- reclamation is attempted when
  // BlockHostUntilDone() is called.
  internal::TemporaryMemoryManager temporary_memory_manager_;

  // Callbacks enqueued to be run after the next call to BlockHostUntilDone().
  std::vector<std::function<void()>> after_block_host_until_done_callbacks_
      ABSL_GUARDED_BY(mu_);
1595 | |
1596 | // Non-extended BLAS interface requires alpha/beta to be floats when input |
1597 | // type is Eigen::half. However, for consistency purposes it is convenient |
1598 | // for the interface to accept Eigen::half. |
1599 | template <typename T> |
1600 | void UpcastHalfToFloat(void **alpha_ptr, void **beta_ptr, |
1601 | float *alpha_storage, float *beta_storage) { |
1602 | if (std::is_same<T, Eigen::half>::value) { |
1603 | *alpha_storage = |
1604 | static_cast<float>(*reinterpret_cast<Eigen::half *>(*alpha_ptr)); |
1605 | *beta_storage = |
1606 | static_cast<float>(*reinterpret_cast<Eigen::half *>(*beta_ptr)); |
1607 | *alpha_ptr = alpha_storage; |
1608 | *beta_ptr = beta_storage; |
1609 | } else if (std::is_same<T, Eigen::bfloat16>::value) { |
1610 | *alpha_storage = |
1611 | static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*alpha_ptr)); |
1612 | *beta_storage = |
1613 | static_cast<float>(*reinterpret_cast<Eigen::bfloat16 *>(*beta_ptr)); |
1614 | *alpha_ptr = alpha_storage; |
1615 | *beta_ptr = beta_storage; |
1616 | } |
1617 | } |
1618 | |
1619 | SE_DISALLOW_COPY_AND_ASSIGN(Stream); |
1620 | }; |
1621 | |
1622 | //////////// |
1623 | // Inlines |
1624 | |
// Launches `kernel` on this stream with the given grid configuration, after
// statically checking that the Args match the kernel's Params. Returns the
// launch status from the parent StreamExecutor.
template <typename... Params, typename... Args>
inline port::Status Stream::ThenLaunch(ThreadDim thread_dims,
                                       BlockDim block_dims,
                                       const TypedKernel<Params...> &kernel,
                                       Args... args) {
  // Compile-time check that each argument type is compatible with the
  // corresponding kernel parameter type.
  KernelInvocationChecker<std::tuple<Params...>,
                          std::tuple<Args...>>::CheckAllStaticAssert();

  // This is the core that allows type-safe kernel launching.
  // Since the platforms take kernel arguments as tuples of (void *, size),
  // we pack the variadic parameters passed as ...args into the desired
  // tuple form and pass that packed form to the StreamExecutor::Launch()
  // implementation.
  KernelArgsArray<sizeof...(args)> kernel_args;
  kernel.PackParams(&kernel_args, args...);
  TF_RETURN_IF_ERROR(
      parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args));
  return ::tsl::OkStatus();
}
1644 | |
// Allocates a temporary array of `element_count` elements of type T whose
// lifetime is managed by this stream's TemporaryMemoryManager (reclaimed
// after BlockHostUntilDone).
template <typename T>
inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
Stream::AllocateTemporaryArray(uint64_t element_count) {
  return temporary_memory_manager_.AllocateArray<T>(element_count);
}

// Accessor for the stream-owned temporary memory manager (internal usage).
inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
  return &temporary_memory_manager_;
}
1654 | |
// Maps element types to the corresponding quantized activation mode used by
// the DNN routines: uint8 -> 8-bit, uint16 -> 16-bit, int32 -> 32-bit.
template <>
struct Quantization<uint8_t> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k8Bit;
};

template <>
struct Quantization<uint16_t> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k16Bit;
};

template <>
struct Quantization<int32_t> {
  static constexpr dnn::QuantizedActivationMode kModeId =
      dnn::QuantizedActivationMode::k32Bit;
};
1672 | |
1673 | } // namespace stream_executor |
1674 | |
1675 | #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_STREAM_H_ |
1676 | |