3#include <AMReX_Config.H>
9#if defined(AMREX_USE_CUDA)
34 using U = std::conditional_t<std::is_const_v<T>,
Long const,
Long>;
48template <
typename T,
template <
typename>
class V>
70 mat.resize(num_non_zeros);
108template <
typename C,
typename T,
template<
typename>
class AD,
template<
typename>
class AS,
109 std::enable_if_t<std::is_same_v<C,Gpu::HostToDevice> ||
110 std::is_same_v<C,Gpu::DeviceToHost> ||
111 std::is_same_v<C,Gpu::DeviceToDevice>,
int> = 0>
121 dst.
mat.resize(src.
mat.size());
139template <
typename T,
template <
typename>
class V>
142 if (nnz <= 0) {
return; }
146#if defined(AMREX_USE_CUDA) || defined(AMREX_USE_HIP)
151 constexpr int nthreads = 256;
156 auto nr =
int(nrows());
157 int nblocks = (nr + nwarps_per_block-1) / nwarps_per_block;
160 auto* pmat = mat.data();
161 auto* pcol = col_index.data();
162 auto* prow = row_offset.data();
165 auto* d_needs_fallback = needs_fallback.
data();
167 amrex::launch_global<nthreads><<<nblocks, nthreads, 0, stream>>>
171 int r =
int(blockIdx.x)*nwarps_per_block + wid;
174 Long const b = prow[r];
175 Long const e = prow[r+1];
176 auto const len =
int(e - b);
178 if (len <= 1)
return;
184 sorted = sorted && (pcol[b+i-1] <= pcol[b+i]);
186#if defined(AMREX_USE_CUDA)
187 if (__all_sync(0xffffffff, sorted)) {
return; }
189 if (__all(sorted)) {
return; }
195 if (len <= ITEMS_PER_WARP)
197#if defined(AMREX_USE_CUDA)
198 using WarpSort = cub::WarpMergeSort<Long, ITEMS_PER_THREAD, Gpu::Device::warp_size, T>;
199 __shared__
typename WarpSort::TempStorage temp_storage[nwarps_per_block];
200#elif defined(AMREX_USE_HIP)
201 using WarpSort = rocprim::warp_sort<Long, Gpu::Device::warp_size, T>;
202 __shared__
typename WarpSort::storage_type temp_storage[nwarps_per_block];
205 Long keys[ITEMS_PER_THREAD];
206 T values[ITEMS_PER_THREAD];
209 for (
int i = 0; i < ITEMS_PER_THREAD; ++i) {
210 int idx = lane * ITEMS_PER_THREAD + i;
212 keys[i] = pcol[b + idx];
213 values[i] = pmat[b + idx];
215 keys[i] = std::numeric_limits<Long>::max();
221 WarpSort{}.sort(keys, values, temp_storage[wid]),
222 WarpSort(temp_storage[wid]).Sort(
226 for (
int i = 0; i < ITEMS_PER_THREAD; ++i) {
227 int idx = lane * ITEMS_PER_THREAD + i;
229 pcol[b + idx] = keys[i];
230 pmat[b + idx] = values[i];
240 auto* h_needs_fallback = needs_fallback.copyToHost();
242 if (*h_needs_fallback)
244 V<Long> col_index_out(col_index.size());
245 V<T> mat_out(mat.size());
246 auto* d_col_out = col_index_out.data();
247 auto* d_val_out = mat_out.data();
249 std::size_t temp_bytes = 0;
252 rocprim::segmented_radix_sort_pairs,
253 cub::DeviceSegmentedRadixSort::SortPairs)
254 (
nullptr, temp_bytes, pcol, d_col_out, pmat, d_val_out,
255 nnz, nr, prow, prow+1, 0,
int(
sizeof(
Long)*CHAR_BIT),
261 rocprim::segmented_radix_sort_pairs,
262 cub::DeviceSegmentedRadixSort::SortPairs)
263 (d_temp, temp_bytes, pcol, d_col_out, pmat, d_val_out,
264 nnz, nr, prow, prow+1, 0,
int(
sizeof(
Long)*CHAR_BIT),
267 std::swap(col_index, col_index_out);
268 std::swap(mat, mat_out);
278#elif defined(AMREX_USE_SYCL)
297template <
typename T,
template <
typename>
class V>
300 if (nnz <= 0) {
return; }
302 constexpr int SMALL = 128;
320 for (
Long r = 0; r < nr; ++r) {
321 Long const b = row_offset[r ];
322 Long const e = row_offset[r+1];
323 auto const len =
int(e - b);
325 if (len <= 1) {
continue; }
328 for (
int i = 1; i < len; ++i) {
329 if (col_index[b+i-1] > col_index[b+i]) {
334 if (sorted) {
continue; }
338 for (
int i = 0; i < len; ++i) {
339 scols[i] = col_index[b+i];
340 svals[i] = mat [b+i];
342 for (
int i = 1; i < len; ++i) {
346 while (j > 0 && scols[j-1] > c) {
347 scols[j] = scols[j-1];
348 svals[j] = svals[j-1];
354 for (
int i = 0; i < len; ++i) {
355 col_index[b+i] = scols[i];
356 mat [b+i] = svals[i];
363 for (
int i = 0; i < len; ++i) {
364 lcols[i] = col_index[b+i];
365 lvals[i] = mat [b+i];
369 std::sort(perm.begin(), perm.end(),
370 [&] (
int i0,
int i1) {
371 return lcols[i0] < lcols[i1];
374 for (
int out = 0; out < len; ++out) {
375 auto const in = perm[out];
376 col_index[b+out] = lcols[in];
377 mat [b+out] = lvals[in];
#define AMREX_ALWAYS_ASSERT(EX)
Definition AMReX_BLassert.H:50
#define AMREX_RESTRICT
Definition AMReX_Extension.H:32
#define AMREX_HIP_OR_CUDA(a, b)
Definition AMReX_GpuControl.H:17
#define AMREX_GPU_SAFE_CALL(call)
Definition AMReX_GpuError.H:63
#define AMREX_GPU_ERROR_CHECK()
Definition AMReX_GpuError.H:151
#define AMREX_GPU_DEVICE
Definition AMReX_GpuQualifiers.H:18
virtual void free(void *pt)=0
A pure virtual function for deleting the arena pointed to by pt.
virtual void * alloc(std::size_t sz)=0
Definition AMReX_GpuBuffer.H:23
T const * data() const noexcept
Definition AMReX_GpuBuffer.H:50
static constexpr int warp_size
Definition AMReX_GpuDevice.H:236
amrex_long Long
Definition AMReX_INT.H:30
Arena * The_Arena()
Definition AMReX_Arena.cpp:820
__host__ __device__ AMREX_FORCE_INLINE void AddNoRet(T *sum, T value) noexcept
Definition AMReX_GpuAtomic.H:283
void copyAsync(HostToDevice, InIter begin, InIter end, OutIter result) noexcept
A host-to-device copy routine. Note this is just a wrapper around memcpy, so it assumes contiguous st...
Definition AMReX_GpuContainers.H:228
static constexpr DeviceToHost deviceToHost
Definition AMReX_GpuContainers.H:106
static constexpr HostToDevice hostToDevice
Definition AMReX_GpuContainers.H:105
void streamSynchronize() noexcept
Definition AMReX_GpuDevice.H:310
gpuStream_t gpuStream() noexcept
Definition AMReX_GpuDevice.H:291
Definition AMReX_Amr.cpp:50
void duplicateCSR(C c, CSR< T, AD > &dst, CSR< T, AS > const &src)
Copy CSR buffers between memory spaces asynchronously.
Definition AMReX_CSR.H:119
const int[]
Definition AMReX_BLProfiler.cpp:1664
Owning CSR container backed by AMReX resizable vectors.
Definition AMReX_CSR.H:49
V< Long > row_offset
Definition AMReX_CSR.H:52
Long nrows() const
Number of logical rows represented by the CSR offset array.
Definition AMReX_CSR.H:56
Long nnz
Definition AMReX_CSR.H:53
void sort()
Sort each row by column index. Uses GPU acceleration when possible.
Definition AMReX_CSR.H:140
CsrView< T > view()
Mutable view of the underlying buffers.
Definition AMReX_CSR.H:77
void sort_on_host()
Host-only fallback that sorts column indices row by row.
Definition AMReX_CSR.H:298
CsrView< T const > view() const
Const view of the underlying buffers.
Definition AMReX_CSR.H:83
CsrView< T const > const_view() const
Convenience alias for view() const.
Definition AMReX_CSR.H:92
void resize(Long num_rows, Long num_non_zeros)
Resize the storage to accommodate num_rows and num_non_zeros entries.
Definition AMReX_CSR.H:69
V< Long > col_index
Definition AMReX_CSR.H:51
V< T > mat
Definition AMReX_CSR.H:50
Lightweight non-owning CSR view that can point to host or device buffers.
Definition AMReX_CSR.H:33
std::conditional_t< std::is_const_v< T >, Long const, Long > U
Definition AMReX_CSR.H:34
T *__restrict__ mat
Definition AMReX_CSR.H:35
Long nrows
Definition AMReX_CSR.H:39
Long nnz
Definition AMReX_CSR.H:38
U *__restrict__ row_offset
Definition AMReX_CSR.H:37
U *__restrict__ col_index
Definition AMReX_CSR.H:36