Skip to content

Commit

Permalink
ARIMA - Kalman loop rewrite: single megakernel instead of host loop (#…
Browse files Browse the repository at this point in the history
…4006)

This PR brings **speedups** on the order of **10x** to seasonal ARIMA. It replaces a legacy host loop based on cuBLAS batched operations and RAPIDS prims with a custom kernel, reducing launch overhead and unnecessary reads and writes in global memory. It also paves the way for supporting missing observations.

The PR also introduces a set of prims in `linalg/block.cuh` for block-local linear algebra operations, along with corresponding unit tests in `test/prims/linalg_block.cu`.

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Robert Maynard (https://github.com/robertmaynard)

URL: #4006
  • Loading branch information
Nyrio authored Jul 12, 2021
1 parent a09aa4c commit c9abba1
Show file tree
Hide file tree
Showing 5 changed files with 1,730 additions and 322 deletions.
9 changes: 3 additions & 6 deletions cpp/include/cuml/tsa/arima_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ struct ARIMAMemory {
T *params_mu, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2, *Tparams_mu,
*Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams,
*Z_dense, *R_dense, *T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense,
*ImT_inv_dense, *T_values, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *vs, *y_diff,
*loglike, *loglike_base, *loglike_pert, *x_pert, *F_buffer, *sumLogF_buffer, *sigma2_buffer,
*ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *vs, *y_diff, *loglike,
*loglike_base, *loglike_pert, *x_pert, *F_buffer, *sumLogF_buffer, *sigma2_buffer,
*I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense, *RQRs_dense, *Ps_dense;
T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches, **P_batches,
**alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches,
**K_batches, **TP_batches, **I_m_AxA_batches, **I_m_AxA_inv_batches, **Ts_batches,
**RQRs_batches, **Ps_batches;
int *T_col_index, *T_row_index, *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;
int *ImT_inv_P, *ImT_inv_info, *I_m_AxA_P, *I_m_AxA_info;

size_t size;

Expand Down Expand Up @@ -281,9 +281,6 @@ struct ARIMAMemory {
append_buffer<assign>(ImT_inv_batches, batch_size);
append_buffer<assign>(ImT_inv_P, r * batch_size);
append_buffer<assign>(ImT_inv_info, batch_size);
append_buffer<assign>(T_values, rd * rd * batch_size);
append_buffer<assign>(T_col_index, rd * rd);
append_buffer<assign>(T_row_index, rd + 1);
append_buffer<assign>(v_tmp_dense, rd * batch_size);
append_buffer<assign>(v_tmp_batches, batch_size);
append_buffer<assign>(m_tmp_dense, rd * rd * batch_size);
Expand Down
Loading

0 comments on commit c9abba1

Please sign in to comment.