After a naive initial port of femto's kernels to CUDA, we had an implementation that looked something like:
```cpp
// start off by setting up our shared memory buffers
uint32_t shmem_offset = 0;
extern __shared__ char shmem[];

nd::view< Connection > shr_connectivity(
  (Connection *)(shmem + shmem_offset),
  {connectivity.shape[1]}
);
shmem_offset += round_up_to_multiple_of_128(shr_connectivity.size() * sizeof(Connection));

uint32_t nodes_per_elem = el.num_nodes();
nd::view< uint32_t > shr_node_ids(
  (uint32_t *)(shmem + shmem_offset),
  {nodes_per_elem}
);
shmem_offset += round_up_to_multiple_of_128(shr_node_ids.size() * sizeof(uint32_t));

nd::view< grad_type > shr_du_dxi_q(
  (grad_type *)(shmem + shmem_offset),
  {qpts_per_elem}
);
shmem_offset += round_up_to_multiple_of_128(shr_du_dxi_q.size() * sizeof(grad_type));

nd::view< double > shr_u_e(
  (double *)(shmem + shmem_offset),
  {nodes_per_elem}
);
shmem_offset += round_up_to_multiple_of_128(shr_u_e.size() * sizeof(double));

nd::view< double, n > shr_shape_fn_grads(
  (double *)(shmem + shmem_offset),
  shape_functions.shape
);
shmem_offset += round_up_to_multiple_of_128(shr_shape_fn_grads.size() * sizeof(double));

nd::view< double > shr_scratch(
  (double *)(shmem + shmem_offset),
  {scratch_size}
);

int elem_tid = threadIdx.x + blockDim.x * threadIdx.y;
int elem_stride = blockDim.x * blockDim.y;
int block_tid = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
int block_stride = blockDim.x * blockDim.y * blockDim.z;

// 0. load shape function evaluations into shared memory
for (int i = block_tid; i < shr_shape_fn_grads.size(); i += block_stride) {
  shr_shape_fn_grads[i] = shape_functions[i];
}

// 1. then figure out which element we're integrating and load its connectivity info
uint32_t e = blockIdx.x;
uint32_t elem_id = elements[e];
for (int i = elem_tid; i < shr_connectivity.size(); i += elem_stride) {
  shr_connectivity[i] = connectivity(elem_id, i);
}
__syncthreads();

// 2. from the connectivity, figure out which nodes belong to this element
if (elem_tid == 0) {
  el.indices(u_offsets, shr_connectivity.data(), shr_node_ids.data());
}
__syncthreads();

int num_components = u.shape[1];
for (int c = 0; c < num_components; c++) {
  // 3. load the nodal values for this element
  for (int i = elem_tid; i < nodes_per_elem; i += elem_stride) {
    shr_u_e(i) = u(shr_node_ids(i), c);
  }
  __syncthreads();

  // 4. interpolate the quadrature point values for this element
  el.cuda_gradient(shr_du_dxi_q, shr_u_e, shr_shape_fn_grads, shr_scratch.data());

  // 5. write those quadrature values out to global memory
  for (int q = elem_tid; q < qpts_per_elem; q += elem_stride) {
    du_dxi_q(qpts_per_elem * e + q, c) = shr_du_dxi_q[q];
  }
}
```
This implementation already shows significant speedups for many combinations of element geometry and polynomial order, but it has an issue: it is written so that each thread block processes a single element. That may work okay when the element is a high-order tetrahedron or hexahedron, but for low-order triangles and quadrilaterals it amounts to launching a kernel with a block size as small as 3.
A block size of 3 may not be a big deal when running on a CPU, but it is problematic for GPUs. On a GPU, calculations are carried out by "warps" (groups of 32 parallel threads), which can consequently operate on up to 32 pieces of data at a time. If we launch a kernel with a block size of only 3, then most of the threads in each warp will be inactive, which means we attain only a fraction of the performance the hardware is capable of.
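To make this concrete, here's a small helper (hypothetical, not part of femto) that estimates what fraction of the launched hardware threads actually do useful work, assuming a warp size of 32:

```cpp
// A block of N threads occupies ceil(N / 32) full warps, so the fraction
// of active hardware threads is N / (warps * 32).
double warp_utilization(int block_size) {
    const int warp_size = 32;
    int warps = (block_size + warp_size - 1) / warp_size;
    return static_cast<double>(block_size) / (warps * warp_size);
}
```

For a block size of 3 this gives 3/32 ≈ 9%, while a block size that is a multiple of 32 gives 100%.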
So, how can we better utilize the GPU hardware for elements that only involve 3-4 values per element?
Well, fundamentally there is not enough work in a single 3-node triangle to keep a warp busy, but we can modify the kernel to process multiple elements at a time, rather than just one. That way, we can ensure that each thread block has enough work to keep the hardware busy.
Here's a sketch of how to do that:
1. Modify the kernel launch parameters:

   ```cpp
   // before: one element / block
   int block_size = nodes_per_element;
   int grid_size = num_elements;
   my_kernel<<<grid_size, block_size>>>(...);

   // after: multiple elements / block
   dim3 block_size = {nodes_per_element, 1, elem_per_block};
   int grid_size = (num_elements + elem_per_block - 1) / elem_per_block;
   my_kernel<<<grid_size, block_size>>>(...);
   ```

2. Inspect `threadIdx.z` in the kernel to figure out which element within the block to work on.

3. Load data and perform calculations on the element corresponding to `threadIdx.z`.
In order to make step 3 work, we also have to increase the amount of shared memory we allocate, to make sure each element has space for its own intermediates. Some of the data (like shape function evaluations) are common to all of the elements in the thread block, so those don't need to be duplicated. Here's some code for the updated body of the kernel, taking into account processing multiple elements per block:
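As a rough host-side illustration of that sizing logic (a sketch, not the actual femto code; the function and parameter names here are made up), the per-block shared memory request scales with `elem_per_block` for every buffer except the shape function table:

```cpp
#include <cstdint>

// Round a byte count up to the next multiple of 128, keeping each buffer
// aligned on a 128-byte boundary (mirrors round_up_to_multiple_of_128).
uint32_t round_up_to_multiple_of_128(uint32_t bytes) {
    return (bytes + 127u) & ~127u;
}

// Hypothetical estimate of the dynamic shared memory needed per block.
// The per-element buffers are duplicated elem_per_block times; the shape
// function gradients are shared by every element in the block.
uint32_t shmem_bytes_per_block(uint32_t elem_per_block,
                               uint32_t nodes_per_elem,
                               uint32_t qpts_per_elem,
                               uint32_t conn_len,
                               uint32_t shape_fn_bytes,
                               uint32_t scratch_size,
                               uint32_t sizeof_connection,
                               uint32_t sizeof_grad) {
    uint32_t total = 0;
    total += round_up_to_multiple_of_128(elem_per_block * conn_len * sizeof_connection);
    total += round_up_to_multiple_of_128(elem_per_block * nodes_per_elem * sizeof(uint32_t));
    total += round_up_to_multiple_of_128(elem_per_block * qpts_per_elem * sizeof_grad);
    total += round_up_to_multiple_of_128(elem_per_block * nodes_per_elem * sizeof(double));
    total += round_up_to_multiple_of_128(shape_fn_bytes); // shared, not duplicated
    total += round_up_to_multiple_of_128(elem_per_block * scratch_size * sizeof(double));
    return total;
}
```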
```cpp
// start off by setting up our shared memory buffers
uint32_t shmem_offset = 0;
extern __shared__ char shmem[];
uint32_t elem_per_block = blockDim.z;

nd::view< Connection, 2 > shr_connectivity(
  (Connection *)(shmem + shmem_offset),
  {elem_per_block, connectivity.shape[1]}
);
shmem_offset += round_up_to_multiple_of_128(shr_connectivity.size() * sizeof(Connection));

uint32_t nodes_per_elem = el.num_nodes();
nd::view< uint32_t, 2 > shr_node_ids(
  (uint32_t *)(shmem + shmem_offset),
  {elem_per_block, nodes_per_elem}
);
shmem_offset += round_up_to_multiple_of_128(shr_node_ids.size() * sizeof(uint32_t));

nd::view< grad_type, 2 > shr_du_dxi_q(
  (grad_type *)(shmem + shmem_offset),
  {elem_per_block, qpts_per_elem}
);
shmem_offset += round_up_to_multiple_of_128(shr_du_dxi_q.size() * sizeof(grad_type));

nd::view< double, 2 > shr_u_e(
  (double *)(shmem + shmem_offset),
  {elem_per_block, nodes_per_elem}
);
shmem_offset += round_up_to_multiple_of_128(shr_u_e.size() * sizeof(double));

nd::view< double, n > shr_shape_fn_grads(
  (double *)(shmem + shmem_offset),
  shape_functions.shape
);
shmem_offset += round_up_to_multiple_of_128(shr_shape_fn_grads.size() * sizeof(double));

nd::view< double, 2 > shr_scratch(
  (double *)(shmem + shmem_offset),
  {elem_per_block, scratch_size}
);

int local_tid = threadIdx.x + blockDim.x * threadIdx.y;
int local_stride = blockDim.x * blockDim.y;
int block_tid = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
int block_stride = blockDim.x * blockDim.y * blockDim.z;

// 0. load shape function evaluations into shared memory
for (int i = block_tid; i < shr_shape_fn_grads.size(); i += block_stride) {
  shr_shape_fn_grads[i] = shape_functions[i];
}

// 1. then figure out which element we're integrating and load its connectivity info
uint32_t e = blockIdx.x * blockDim.z + threadIdx.z;
if (e < elements.shape[0]) {
  uint32_t elem_id = elements[e];
  uint32_t local_elem_id = threadIdx.z;
  for (int i = local_tid; i < shr_connectivity.shape[1]; i += local_stride) {
    shr_connectivity(local_elem_id, i) = connectivity(elem_id, i);
  }
  __syncthreads();

  // 2. from the connectivity, figure out which nodes belong to this element
  if (local_tid == 0) {
    el.indices(u_offsets, &shr_connectivity(local_elem_id, 0), &shr_node_ids(local_elem_id, 0));
  }
  __syncthreads();

  int num_components = u.shape[1];
  for (int c = 0; c < num_components; c++) {
    // 3. load the nodal values for this element
    for (int i = local_tid; i < nodes_per_elem; i += local_stride) {
      shr_u_e(local_elem_id, i) = u(shr_node_ids(local_elem_id, i), c);
    }
    __syncthreads();

    // 4. interpolate the quadrature point values for this element
    el.cuda_gradient(
      shr_du_dxi_q(local_elem_id),
      shr_u_e(local_elem_id),
      shr_shape_fn_grads,
      &shr_scratch(local_elem_id, 0)
    );

    // 5. write those quadrature values out to global memory
    for (int q = local_tid; q < qpts_per_elem; q += local_stride) {
      du_dxi_q(qpts_per_elem * e + q, c) = shr_du_dxi_q(local_elem_id, q);
    }
  }
}
```
This modification doesn't require significant changes to the kernel. The main difference is that the shared memory arrays now have one extra dimension, which each thread indexes by its local element id.
Great, now the kernel definition is able to handle multiple elements per block, but how many elements should each thread block process?
Intuitively, we should pick enough elements per block that the warps in the thread block are utilized efficiently. This is guaranteed when the block size `nodes_per_elem * elem_per_block` is a multiple of 32, but that isn't a necessary condition. Also, recall that processing more elements per block requires more shared memory. So, beyond a point, increasing the number of elements per block ends up using enough shared memory to hurt performance through decreased occupancy.
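A back-of-the-envelope sketch of that occupancy effect (hypothetical numbers: the 96 KB of shared memory per SM and the 32-block cap are assumptions; the real limits depend on the GPU and launch configuration):

```cpp
// How shared memory caps the number of blocks resident on one SM:
// each resident block reserves its full dynamic shared memory request.
int max_blocks_per_sm(int shmem_per_block_bytes) {
    const int shmem_per_sm = 96 * 1024;   // assumed shared memory per SM
    const int hw_block_cap = 32;          // assumed hardware block-count cap
    if (shmem_per_block_bytes <= 0) return hw_block_cap;
    int blocks = shmem_per_sm / shmem_per_block_bytes;
    return blocks < hw_block_cap ? blocks : hw_block_cap;
}
```

Under these assumptions, a 4 KB request allows 24 resident blocks, while a 48 KB request allows only 2; doubling `elem_per_block` roughly doubles the request, so occupancy can fall even as per-block utilization rises.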
From those criteria (picking the smallest count that maximizes the expected number of active threads per warp), we could guess parameter values for each kind of element, but instead I just set up a performance benchmark and exhaustively tried different numbers of elements per block (up to a maximum of 32 for the 2D elements and up to 16 for the 3D elements).
The following sections tabulate data from a performance benchmark on unstructured meshes of ~1,000,000 elements of the specified geometry.
The tests were performed with as many quadrature points as nodes in the respective element, and the listed values are timings (in microseconds) measured on an NVIDIA GV100 GPU.
The labels "scalar" and "vector" indicate which timings come from interpolation of scalar-valued and vector-valued (2 components for tri/quad, 3 components for tet/hex) fields, respectively.
**Triangles**

elems/block | p=1 scalar | p=1 vector | p=2 scalar | p=2 vector | p=3 scalar | p=3 vector |
---|---|---|---|---|---|---|
1 | 1827.86 | 1650.38 | 1657.97 | 1842.93 | 2519.34 | 3418.98 |
2 | 928.676 | 841.768 | 846.261 | 1059.29 | 1197.07 | 1854.88 |
3 | 630.497 | 576.831 | 580.467 | 969.75 | 903.116 | 1626.01 |
4 | 481.159 | 500.158 | 461.676 | 913.121 | 928.326 | 1632.54 |
5 | 392.553 | 503.027 | 412.534 | 911.477 | 878.766 | 1600.71 |
6 | 333.971 | 477.727 | 468.3 | 889.149 | 819.912 | 1547.96 |
7 | 326.042 | 480.17 | 447.254 | 899.226 | 892.56 | 1528.68 |
8 | 262.445 | 470.527 | 415.596 | 878.654 | 841.106 | 1476.84 |
9 | 254.571 | 478.638 | 411.888 | 895.126 | 834.023 | 1511.27 |
10 | 243.835 | 473.16 | 393.825 | 882.801 | 861.805 | 1489.36 |
11 | 287.049 | 482.172 | 438.747 | 899.418 | 836.573 | 1504.07 |
12 | 270.981 | 465.057 | 424.349 | 883.528 | 811.112 | 1485.53 |
13 | 265.044 | 475.665 | 411.909 | 893.59 | 851.768 | 1502.35 |
14 | 256.67 | 469.396 | 404.22 | 880.548 | 837.254 | 1485.52 |
15 | 254.173 | 470.795 | 397.545 | 891.093 | 818.696 | 1495.94 |
16 | 247.275 | 459.956 | 384.163 | 888.107 | 784.377 | 1480.86 |
17 | 251.62 | 470.28 | 416.948 | 892.995 | 871.313 | 1520.84 |
18 | 244.949 | 467.163 | 404.171 | 878.305 | 839.328 | 1491.74 |
19 | 256.185 | 486.186 | 402.933 | 889.179 | 825.923 | 1505.27 |
20 | 252.626 | 478.833 | 392.897 | 881.011 | 852.22 | 1521.06 |
21 | 252.091 | 486.237 | 389.807 | 885.707 | 852.066 | 1530.15 |
22 | 249.021 | 468.597 | 406.985 | 874.233 | 836.657 | 1515.23 |
23 | 249.692 | 471.702 | 406.936 | 883.057 | 834.185 | 1561.02 |
24 | 245.635 | 461.386 | 397.223 | 862.413 | 813.613 | 1534.48 |
25 | 247.615 | 473.833 | 398.911 | 873.469 | 813.435 | 1565.22 |
26 | 245.002 | 467.378 | 391.892 | 862.249 | 863.425 | 1559.67 |
27 | 246.895 | 469.076 | 423.608 | 883.015 | 859.184 | 1569.57 |
28 | 242.812 | 459.521 | 415.038 | 870.332 | 837.382 | 1553.05 |
29 | 243.912 | 467.67 | 412.989 | 878.295 | 832.407 | 1575.21 |
30 | 242.796 | 465.201 | 408.271 | 867.699 | 825.056 | 1565.11 |
31 | 244.91 | 467.52 | 407.235 | 876.127 | 818.293 | 1577.99 |
32 | 239.306 | 464.493 | 398.204 | 878.589 | 795.347 | 1567.85 |
**Quadrilaterals**

elems/block | p=1 scalar | p=1 vector | p=2 scalar | p=2 vector | p=3 scalar | p=3 vector |
---|---|---|---|---|---|---|
1 | 1656.78 | 2218.67 | 1652.46 | 2321.88 | 1945.38 | 3257.68 |
2 | 845.11 | 1143.16 | 874.964 | 1447.73 | 1155.96 | 2370.33 |
3 | 597.446 | 815.569 | 655.854 | 1301.68 | 1319.27 | 2414.98 |
4 | 466.877 | 677.232 | 804.536 | 1349.37 | 1098.34 | 2273.14 |
5 | 407.264 | 619.296 | 693.071 | 1305.25 | 1211.96 | 2303.77 |
6 | 357.307 | 590.028 | 625.527 | 1278.11 | 1084.49 | 2247.2 |
7 | 335.737 | 579.683 | 574.287 | 1269.97 | 1181.23 | 2290.01 |
8 | 320.743 | 583.936 | 652.137 | 1270.68 | 1089.45 | 2184.34 |
9 | 424.108 | 630.525 | 625.074 | 1284.37 | 1192.98 | 2322.53 |
10 | 399.961 | 606.691 | 584.138 | 1265.55 | 1120.31 | 2286.34 |
11 | 375.348 | 593.815 | 645.648 | 1294.83 | 1180.93 | 2330.26 |
12 | 342.942 | 585.122 | 611.476 | 1254.76 | 1123.15 | 2307.1 |
13 | 337.111 | 581.674 | 595.632 | 1274.77 | 1209.03 | 2393.54 |
14 | 331.777 | 582.223 | 568.566 | 1256.25 | 1153.51 | 2370.01 |
15 | 337.931 | 576.605 | 640.121 | 1296.06 | 1285.81 | 2524.15 |
16 | 331.768 | 583.422 | 610.019 | 1250.67 | 1236.86 | 2487.09 |
17 | 359.151 | 599.496 | 601.05 | 1272.71 | 1233.75 | 2486.11 |
18 | 343.784 | 594.612 | 642.87 | 1290.28 | 1192.57 | 2470.38 |
19 | 341.156 | 586.575 | 626.814 | 1285.81 | 1359.22 | 2814.6 |
20 | 334.723 | 582.002 | 606.753 | 1256.02 | 1318.22 | 2787.66 |
21 | 343.475 | 581.999 | 596.031 | 1268.99 | 1310.35 | 2855.63 |
22 | 338.722 | 582.531 | 638.675 | 1306.55 | 1273.36 | 2831.69 |
23 | 341.062 | 578.47 | 628.988 | 1303.96 | 1268.08 | 2830.22 |
24 | 336.373 | 582.976 | 608.801 | 1271.78 | 1235.44 | 2785.61 |
25 | 344.9 | 591.105 | 686.101 | 1366.73 | 1618.81 | 3021.18 |
26 | 333.62 | 589.11 | 668.891 | 1344.01 | 1577.15 | 2978.97 |
27 | 338.581 | 585.48 | 661.15 | 1346.66 | 1558.62 | 2931.92 |
28 | 336.218 | 581.596 | 641.111 | 1311.83 | 1517.17 | 2900.92 |
29 | 340.686 | 581.868 | 651.779 | 1353.14 | 1603.42 | 2930.69 |
30 | 336.04 | 580.796 | 636.678 | 1337.77 | 1568.33 | 2918.19 |
31 | 340.078 | 578.632 | 633.525 | 1339.32 | 1538.87 | 2908.85 |
32 | 333.727 | 561.618 | 607.717 | 1315.42 | 1506.12 | 2885.17 |
**Tetrahedra**

elems/block | p=1 scalar | p=1 vector | p=2 scalar | p=2 vector | p=3 scalar | p=3 vector |
---|---|---|---|---|---|---|
1 | 1918.18 | 2579.59 | 3598.5 | 6195.63 | 18699.4 | 28302.8 |
2 | 956.533 | 1884.27 | 1906.29 | 4638.2 | 8870.11 | 15803.4 |
3 | 740.434 | 1734.03 | 1530.37 | 4319.47 | 6070.27 | 12413.5 |
4 | 612.491 | 1644.02 | 1416.06 | 4164.69 | 5095.22 | 11988.8 |
5 | 648.159 | 1680.53 | 1297.22 | 4165.55 | 4902.87 | 11987.2 |
6 | 593.74 | 1728.37 | 1238.6 | 4176.59 | 4556.27 | 11918.2 |
7 | 550.643 | 1726.98 | 1289.43 | 4142.84 | 4430.8 | 11620 |
8 | 543.714 | 1731.24 | 1190.69 | 4145.25 | 4246.14 | 11527.8 |
9 | 558.379 | 1664.43 | 1192.71 | 4417.58 | 4361.39 | 11920.3 |
10 | 560.016 | 1717.08 | 1222.56 | 4164.6 | 4251.01 | 11878.3 |
11 | 551.588 | 1717.03 | 1201.95 | 4299.83 | 4214.91 | 11852.8 |
12 | 528.626 | 1701.24 | 1166.93 | 4529.63 | 4291.91 | 12215.7 |
13 | 534.903 | 1723.37 | 1226.06 | 4292.24 | 4278.57 | 12079.4 |
14 | 532.193 | 1740.74 | 1200.94 | 4435.75 | 4224.92 | 12078.4 |
15 | 541.353 | 1778.26 | 1171.28 | 4633.82 | 4221.01 | 12217.9 |
16 | 532.543 | 1817.12 | 1139.31 | 4785.83 | 4177.31 | 12249.3 |
**Hexahedra**

elems/block | p=1 scalar | p=1 vector | p=2 scalar | p=2 vector | p=3 scalar | p=3 vector |
---|---|---|---|---|---|---|
1 | 2973.04 | 6554.48 | 3589.08 | 11388.5 | 13286.9 | 36235.7 |
2 | 1728.23 | 4022.05 | 3522.07 | 11097.3 | 12991.4 | 34682.7 |
3 | 1346.82 | 3423.51 | 3536.49 | 10902.8 | 14004.4 | 35837.6 |
4 | 1171.54 | 3179.03 | 3477.84 | 10801.5 | 14363 | 37169.5 |
5 | 1466.2 | 3598.96 | 3865.85 | 11473.8 | 17600.5 | 43768.8 |
6 | 1293.77 | 3314.87 | 4127.49 | 11899.6 | 16425.1 | 42558.9 |
7 | 1262.31 | 3252.9 | 3778.38 | 11359.4 | 15396.6 | 42125.9 |
8 | 1173.44 | 3135.41 | 3504.4 | 11232.2 | 21478 | 56800.4 |
9 | 1323.42 | 3333.4 | 3993.72 | 11901.2 | 20393.1 | 54862.6 |
10 | 1263.75 | 3202.47 | 3805.68 | 11709.6 | 19505.7 | 53456.8 |
11 | 1255.96 | 3205.36 | 4870.32 | 13488.4 | 18882.5 | 52406.2 |
12 | 1187.45 | 3129.54 | 4709.11 | 13191.2 | 18342.1 | 51557.9 |
13 | 1275.19 | 3269.11 | 4538.37 | 12974.7 | 18026.9 | 50993.4 |
14 | 1236.69 | 3185.27 | 4426.16 | 13116.7 | * | * |
15 | 1224.62 | 3221.43 | 4328.54 | 12747.9 | * | * |
16 | 1182.91 | 3165.8 | 4164.77 | 12591.6 | * | * |
Note: the entries marked with * require more shared memory than the GV100 has available.
The naive CUDA port of the CPU code was a good start, but it didn't perform very well for the low-order elements (which are the most commonly used case). Luckily, it didn't take much effort to modify the kernel to accept a parameter that controls how many elements to process in a single thread block. However, introducing a new parameter meant that we also needed to provide appropriate values for it in each situation. By exhaustively sweeping over the number of elements processed per thread block, we were able to select values that significantly improved performance (7.5x for linear triangles, 5x for linear quads, ...).
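Those speedups fall straight out of the tables above, taking the ratio of the one-element-per-block baseline to the best multi-element timing:

```cpp
// Speedup of the best multi-element-per-block timing over the
// one-element-per-block baseline (both in microseconds).
double speedup(double baseline_us, double best_us) {
    return baseline_us / best_us;
}
```

For p=1 scalar triangles, `speedup(1827.86, 239.306)` ≈ 7.6x; for p=1 scalar quads, `speedup(1656.78, 320.743)` ≈ 5.2x.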