diff --git a/cpp/bench/sg/arima_loglikelihood.cu b/cpp/bench/sg/arima_loglikelihood.cu index 1cd8929922..9e0f1b3754 100644 --- a/cpp/bench/sg/arima_loglikelihood.cu +++ b/cpp/bench/sg/arima_loglikelihood.cu @@ -79,6 +79,7 @@ class ArimaLoglikelihood : public TsFixtureRandom { batched_loglike(*this->handle, arima_mem, this->data.X.data(), + nullptr, this->params.batch_size, this->params.n_obs, order, @@ -122,11 +123,11 @@ std::vector getInputs() struct std::vector out; ArimaParams p; p.data.seed = 12345ULL; - std::vector list_order = {{1, 1, 1, 0, 0, 0, 0, 0}, - {1, 1, 1, 1, 1, 1, 4, 0}, - {1, 1, 1, 1, 1, 1, 12, 0}, - {1, 1, 1, 1, 1, 1, 24, 0}, - {1, 1, 1, 1, 1, 1, 52, 0}}; + std::vector list_order = {{1, 1, 1, 0, 0, 0, 0, 0, 0}, + {1, 1, 1, 1, 1, 1, 4, 0, 0}, + {1, 1, 1, 1, 1, 1, 12, 0, 0}, + {1, 1, 1, 1, 1, 1, 24, 0, 0}, + {1, 1, 1, 1, 1, 1, 52, 0, 0}}; std::vector list_batch_size = {10, 100, 1000, 10000}; std::vector list_n_obs = {200, 500, 1000}; for (auto& order : list_order) { diff --git a/cpp/include/cuml/tsa/arima_common.h b/cpp/include/cuml/tsa/arima_common.h index 2ed9da31e2..1f4f3554a3 100644 --- a/cpp/include/cuml/tsa/arima_common.h +++ b/cpp/include/cuml/tsa/arima_common.h @@ -37,15 +37,16 @@ struct ARIMAOrder { int P; // Seasonal order int D; int Q; - int s; // Seasonal period - int k; // Fit intercept? + int s; // Seasonal period + int k; // Fit intercept? + int n_exog; // Number of exogenous regressors inline int n_diff() const { return d + s * D; } inline int n_phi() const { return p + s * P; } inline int n_theta() const { return q + s * Q; } inline int r() const { return std::max(n_phi(), n_theta() + 1); } inline int rd() const { return n_diff() + r(); } - inline int complexity() const { return p + P + q + Q + k + 1; } + inline int complexity() const { return p + P + q + Q + k + n_exog + 1; } inline bool need_diff() const { return static_cast(d + D); } }; @@ -58,6 +59,7 @@ struct ARIMAOrder { template struct ARIMAParams { DataT* mu = nullptr; + DataT* beta = nullptr; DataT* ar = nullptr; DataT* ma = nullptr; DataT* sar = nullptr; @@ -77,6 +79,8 @@ struct ARIMAParams { { rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource(); if (order.k && !tr) mu = (DataT*)rmm_alloc->allocate(batch_size * sizeof(DataT), stream); + if (order.n_exog && !tr) + beta = (DataT*)rmm_alloc->allocate(order.n_exog * batch_size * sizeof(DataT), stream); if (order.p) ar = (DataT*)rmm_alloc->allocate(order.p * batch_size * sizeof(DataT), stream); if (order.q) ma = (DataT*)rmm_alloc->allocate(order.q * batch_size * sizeof(DataT), stream); if (order.P) sar = (DataT*)rmm_alloc->allocate(order.P * batch_size * sizeof(DataT), stream); @@ -97,6 +101,8 @@ struct ARIMAParams { { rmm::mr::device_memory_resource* rmm_alloc = rmm::mr::get_current_device_resource(); if (order.k && !tr) rmm_alloc->deallocate(mu, batch_size * sizeof(DataT), stream); + if (order.n_exog && !tr) + rmm_alloc->deallocate(beta, order.n_exog * batch_size * sizeof(DataT), stream); if (order.p) rmm_alloc->deallocate(ar, order.p * batch_size * sizeof(DataT), stream); if (order.q) rmm_alloc->deallocate(ma, order.q * batch_size * sizeof(DataT), stream); if (order.P) rmm_alloc->deallocate(sar, order.P * batch_size * sizeof(DataT), stream); @@ -118,7 +124,8 @@ struct ARIMAParams { int N = order.complexity(); auto counting = thrust::make_counting_iterator(0); // The device lambda can't capture structure members... - const DataT *_mu = mu, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma, *_sigma2 = sigma2; + const DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma, + *_sigma2 = sigma2; thrust::for_each( thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) { DataT* param = param_vec + bid * N; @@ -126,6 +133,10 @@ struct ARIMAParams { *param = _mu[bid]; param++; } + for (int i = 0; i < order.n_exog; i++) { + param[i] = _beta[order.n_exog * bid + i]; + } + param += order.n_exog; for (int ip = 0; ip < order.p; ip++) { param[ip] = _ar[order.p * bid + ip]; } @@ -160,7 +171,8 @@ struct ARIMAParams { int N = order.complexity(); auto counting = thrust::make_counting_iterator(0); // The device lambda can't capture structure members... - DataT *_mu = mu, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma, *_sigma2 = sigma2; + DataT *_mu = mu, *_beta = beta, *_ar = ar, *_ma = ma, *_sar = sar, *_sma = sma, + *_sigma2 = sigma2; thrust::for_each( thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) { const DataT* param = param_vec + bid * N; @@ -168,6 +180,10 @@ struct ARIMAParams { _mu[bid] = *param; param++; } + for (int i = 0; i < order.n_exog; i++) { + _beta[order.n_exog * bid + i] = param[i]; + } + param += order.n_exog; for (int ip = 0; ip < order.p; ip++) { _ar[order.p * bid + ip] = param[ip]; } @@ -197,11 +213,11 @@ struct ARIMAParams { */ template struct ARIMAMemory { - T *params_mu, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2, *Tparams_mu, + T *params_mu, *params_beta, *params_ar, *params_ma, *params_sar, *params_sma, *params_sigma2, *Tparams_ar, *Tparams_ma, *Tparams_sar, *Tparams_sma, *Tparams_sigma2, *d_params, *d_Tparams, *Z_dense, *R_dense, *T_dense, *RQR_dense, *RQ_dense, *P_dense, *alpha_dense, *ImT_dense, - *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *loglike, - *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense, + *ImT_inv_dense, *v_tmp_dense, *m_tmp_dense, *K_dense, *TP_dense, *pred, *y_diff, *exog_diff, + *loglike, *loglike_base, *loglike_pert, *x_pert, *I_m_AxA_dense, *I_m_AxA_inv_dense, *Ts_dense, *RQRs_dense, *Ps_dense; T **Z_batches, **R_batches, **T_batches, **RQR_batches, **RQ_batches, **P_batches, **alpha_batches, **ImT_batches, **ImT_inv_batches, **v_tmp_batches, **m_tmp_batches, @@ -236,13 +252,13 @@ struct ARIMAMemory { int n_diff = order.n_diff(); append_buffer(params_mu, order.k * batch_size); + append_buffer(params_beta, order.n_exog * batch_size); append_buffer(params_ar, order.p * batch_size); append_buffer(params_ma, order.q * batch_size); append_buffer(params_sar, order.P * batch_size); append_buffer(params_sma, order.Q * batch_size); append_buffer(params_sigma2, batch_size); - append_buffer(Tparams_mu, order.k * batch_size); append_buffer(Tparams_ar, order.p * batch_size); append_buffer(Tparams_ma, order.q * batch_size); append_buffer(Tparams_sar, order.P * batch_size); @@ -282,6 +298,7 @@ struct ARIMAMemory { append_buffer(pred, n_obs * batch_size); append_buffer(y_diff, n_obs * batch_size); + append_buffer(exog_diff, n_obs * order.n_exog * batch_size); append_buffer(loglike, batch_size); append_buffer(loglike_base, batch_size); append_buffer(loglike_pert, batch_size); diff --git a/cpp/include/cuml/tsa/batched_arima.hpp b/cpp/include/cuml/tsa/batched_arima.hpp index ad3c78a05c..7612c2fe34 100644 --- a/cpp/include/cuml/tsa/batched_arima.hpp +++ b/cpp/include/cuml/tsa/batched_arima.hpp @@ -90,6 +90,8 @@ void batched_diff(raft::handle_t& handle, * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) + * @param[in] d_exog Exogenous variables: shape = (n_obs, n_exog * batch_size) and + * expects column major data layout. (device) * @param[in] batch_size Number of time series * @param[in] n_obs Number of observations in a time series * @param[in] order ARIMA hyper-parameters @@ -101,16 +103,11 @@ void batched_diff(raft::handle_t& handle, * @param[in] method Whether to use sum-of-squares or Kalman filter * @param[in] truncate For CSS, start the sum-of-squares after a given * number of observations - * @param[in] fc_steps Number of steps to forecast - * @param[in] d_fc Array to store the forecast - * @param[in] level Confidence level for prediction intervals. 0 to - * skip the computation. Else 0 < level < 1 - * @param[out] d_lower Lower limit of the prediction interval - * @param[out] d_upper Upper limit of the prediction interval */ void batched_loglike(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -119,12 +116,7 @@ void batched_loglike(raft::handle_t& handle, bool trans = true, bool host_loglike = true, LoglikeMethod method = MLE, - int truncate = 0, - int fc_steps = 0, - double* d_fc = nullptr, - double level = 0, - double* d_lower = nullptr, - double* d_upper = nullptr); + int truncate = 0); /** * Compute the loglikelihood of the given parameter on the given time series @@ -137,6 +129,8 @@ void batched_loglike(raft::handle_t& handle, * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) + * @param[in] d_exog Exogenous variables: shape = (n_obs, n_exog * batch_size) and + * expects column major data layout. (device) * @param[in] batch_size Number of time series * @param[in] n_obs Number of observations in a time series * @param[in] order ARIMA hyper-parameters @@ -149,6 +143,8 @@ void batched_loglike(raft::handle_t& handle, * number of observations * @param[in] fc_steps Number of steps to forecast * @param[in] d_fc Array to store the forecast + * @param[in] d_exog_fut Future values of exogenous variables + * Shape (fc_steps, n_exog * batch_size) (col-major, device) * @param[in] level Confidence level for prediction intervals. 0 to * skip the computation. Else 0 < level < 1 * @param[out] d_lower Lower limit of the prediction interval @@ -157,20 +153,22 @@ void batched_loglike(raft::handle_t& handle, void batched_loglike(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, const ARIMAParams& params, double* loglike, - bool trans = true, - bool host_loglike = true, - LoglikeMethod method = MLE, - int truncate = 0, - int fc_steps = 0, - double* d_fc = nullptr, - double level = 0, - double* d_lower = nullptr, - double* d_upper = nullptr); + bool trans = true, + bool host_loglike = true, + LoglikeMethod method = MLE, + int truncate = 0, + int fc_steps = 0, + double* d_fc = nullptr, + const double* d_exog_fut = nullptr, + double level = 0, + double* d_lower = nullptr, + double* d_upper = nullptr); /** * Compute the gradient of the log-likelihood @@ -179,6 +177,8 @@ void batched_loglike(raft::handle_t& handle, * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) + * @param[in] d_exog Exogenous variables: shape = (n_obs, n_exog * batch_size) and + * expects column major data layout. (device) * @param[in] batch_size Number of time series * @param[in] n_obs Number of observations in a time series * @param[in] order ARIMA hyper-parameters @@ -193,6 +193,7 @@ void batched_loglike(raft::handle_t& handle, void batched_loglike_grad(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -211,6 +212,10 @@ void batched_loglike_grad(raft::handle_t& handle, * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Batched Time series to predict. * Shape: (num_samples, batch size) (device) + * @param[in] d_exog Exogenous variables. + * Shape = (n_obs, n_exog * batch_size) (device) + * @param[in] d_exog_fut Future values of exogenous variables + * Shape: (end - n_obs, batch_size) (device) * @param[in] batch_size Total number of batched time series * @param[in] n_obs Number of samples per time series * (all series must be identical) @@ -228,6 +233,8 @@ void batched_loglike_grad(raft::handle_t& handle, void predict(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, + const double* d_exog_fut, int batch_size, int n_obs, int start, @@ -247,6 +254,8 @@ void predict(raft::handle_t& handle, * @param[in] arima_mem Pre-allocated temporary memory * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) + * @param[in] d_exog Exogenous variables. + * Shape = (n_obs, n_exog * batch_size) (device) * @param[in] batch_size Total number of batched time series * @param[in] n_obs Number of samples per time series * (all series must be identical) @@ -260,6 +269,7 @@ void predict(raft::handle_t& handle, void information_criterion(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -274,6 +284,8 @@ void information_criterion(raft::handle_t& handle, * @param[in] params ARIMA parameters (device) * @param[in] d_y Series to fit: shape = (n_obs, batch_size) and * expects column major data layout. (device) + * @param[in] d_exog Exogenous variables. + * Shape = (n_obs, n_exog * batch_size) (device) * @param[in] batch_size Total number of batched time series * @param[in] n_obs Number of samples per time series * (all series must be identical) @@ -283,6 +295,7 @@ void information_criterion(raft::handle_t& handle, void estimate_x0(raft::handle_t& handle, ARIMAParams& params, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, diff --git a/cpp/include/cuml/tsa/batched_kalman.hpp b/cpp/include/cuml/tsa/batched_kalman.hpp index 4388dddee5..b9693b01bd 100644 --- a/cpp/include/cuml/tsa/batched_kalman.hpp +++ b/cpp/include/cuml/tsa/batched_kalman.hpp @@ -30,8 +30,10 @@ namespace ML { * * @param[in] handle cuML handle * @param[in] arima_mem Pre-allocated temporary memory - * @param[in] d_ys_b Batched time series + * @param[in] d_ys Batched time series * Shape (nobs, batch_size) (col-major, device) + * @param[in] d_exog Batched exogenous variables + * Shape (nobs, n_exog * batch_size) (col-major, device) * @param[in] nobs Number of samples per time series * @param[in] params ARIMA parameters (device) * @param[in] order ARIMA hyper-parameters @@ -41,6 +43,8 @@ namespace ML { * shape=(nobs-d-s*D, batch_size) (device) * @param[in] fc_steps Number of steps to forecast * @param[in] d_fc Array to store the forecast + * @param[in] d_exog_fut Future values of exogenous variables + * Shape (fc_steps, n_exog * batch_size) (col-major, device) * @param[in] level Confidence level for prediction intervals. 0 to * skip the computation. Else 0 < level < 1 * @param[out] d_lower Lower limit of the prediction interval @@ -48,18 +52,20 @@ namespace ML { */ void batched_kalman_filter(raft::handle_t& handle, const ARIMAMemory& arima_mem, - const double* d_ys_b, + const double* d_ys, + const double* d_exog, int nobs, const ARIMAParams& params, const ARIMAOrder& order, int batch_size, double* d_loglike, double* d_pred, - int fc_steps = 0, - double* d_fc = nullptr, - double level = 0, - double* d_lower = nullptr, - double* d_upper = nullptr); + int fc_steps = 0, + double* d_fc = nullptr, + const double* d_exog_fut = nullptr, + double level = 0, + double* d_lower = nullptr, + double* d_upper = nullptr); /** * Convenience function for batched "jones transform" used in ARIMA to ensure diff --git a/cpp/src/arima/batched_arima.cu b/cpp/src/arima/batched_arima.cu index a73891db6d..cb91831a7f 100644 --- a/cpp/src/arima/batched_arima.cu +++ b/cpp/src/arima/batched_arima.cu @@ -90,6 +90,8 @@ bool detect_missing(raft::handle_t& handle, const double* d_y, int n_elem) void predict(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, + const double* d_exog_fut, int batch_size, int n_obs, int start, @@ -105,29 +107,61 @@ void predict(raft::handle_t& handle, ML::PUSH_RANGE(__func__); const auto stream = handle.get_stream(); - bool diff = order.need_diff() && pre_diff && level == 0; + bool diff = order.need_diff() && pre_diff && level == 0; + int num_steps = std::max(end - n_obs, 0); // Prepare data int n_obs_kf; const double* d_y_kf; + const double* d_exog_kf; + const double* d_exog_fut_kf = d_exog_fut; ARIMAOrder order_after_prep = order; + rmm::device_uvector exog_fut_buffer(0, stream); if (diff) { n_obs_kf = n_obs - order.n_diff(); MLCommon::TimeSeries::prepare_data( arima_mem.y_diff, d_y, batch_size, n_obs, order.d, order.D, order.s, stream); + if (order.n_exog > 0) { + MLCommon::TimeSeries::prepare_data(arima_mem.exog_diff, + d_exog, + order.n_exog * batch_size, + n_obs, + order.d, + order.D, + order.s, + stream); + + if (num_steps > 0) { + exog_fut_buffer.resize(num_steps * order.n_exog * batch_size, stream); + + MLCommon::TimeSeries::prepare_future_data(exog_fut_buffer.data(), + d_exog, + d_exog_fut, + order.n_exog * batch_size, + n_obs, + num_steps, + order.d, + order.D, + order.s, + stream); + + d_exog_fut_kf = exog_fut_buffer.data(); + } + } order_after_prep.d = 0; order_after_prep.D = 0; - d_y_kf = arima_mem.y_diff; + d_y_kf = arima_mem.y_diff; + d_exog_kf = arima_mem.exog_diff; } else { - n_obs_kf = n_obs; - d_y_kf = d_y; + n_obs_kf = n_obs; + d_y_kf = d_y; + d_exog_kf = d_exog; } double* d_pred = arima_mem.pred; // Create temporary array for the forecasts - int num_steps = std::max(end - n_obs, 0); rmm::device_uvector fc_buffer(num_steps * batch_size, stream); double* d_y_fc = fc_buffer.data(); @@ -137,6 +171,7 @@ void predict(raft::handle_t& handle, batched_loglike(handle, arima_mem, d_y_kf, + d_exog_kf, batch_size, n_obs_kf, order_after_prep, @@ -148,6 +183,7 @@ void predict(raft::handle_t& handle, 0, num_steps, d_y_fc, + d_exog_fut_kf, level, d_lower, d_upper); @@ -364,6 +400,7 @@ void conditional_sum_of_squares(raft::handle_t& handle, void batched_loglike(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -375,6 +412,7 @@ void batched_loglike(raft::handle_t& handle, int truncate, int fc_steps, double* d_fc, + const double* d_exog_fut, double level, double* d_lower, double* d_upper) @@ -385,7 +423,8 @@ void batched_loglike(raft::handle_t& handle, double* d_pred = arima_mem.pred; - ARIMAParams Tparams = {arima_mem.Tparams_mu, + ARIMAParams Tparams = {params.mu, + params.beta, arima_mem.Tparams_ar, arima_mem.Tparams_ma, arima_mem.Tparams_sar, @@ -400,11 +439,8 @@ void batched_loglike(raft::handle_t& handle, if (trans) { MLCommon::TimeSeries::batched_jones_transform( order, batch_size, false, params, Tparams, stream); - - Tparams.mu = params.mu; } else { // non-transformed case: just use original parameters - Tparams.mu = params.mu; Tparams.ar = params.ar; Tparams.ma = params.ma; Tparams.sar = params.sar; @@ -418,6 +454,7 @@ void batched_loglike(raft::handle_t& handle, batched_kalman_filter(handle, arima_mem, d_y, + d_exog, n_obs, Tparams, order, @@ -426,6 +463,7 @@ void batched_loglike(raft::handle_t& handle, d_pred, fc_steps, d_fc, + d_exog_fut, level, d_lower, d_upper); @@ -441,6 +479,7 @@ void batched_loglike(raft::handle_t& handle, void batched_loglike(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -449,12 +488,7 @@ void batched_loglike(raft::handle_t& handle, bool trans, bool host_loglike, LoglikeMethod method, - int truncate, - int fc_steps, - double* d_fc, - double level, - double* d_lower, - double* d_upper) + int truncate) { ML::PUSH_RANGE(__func__); @@ -462,6 +496,7 @@ void batched_loglike(raft::handle_t& handle, auto stream = handle.get_stream(); ARIMAParams params = {arima_mem.params_mu, + arima_mem.params_beta, arima_mem.params_ar, arima_mem.params_ma, arima_mem.params_sar, @@ -473,6 +508,7 @@ void batched_loglike(raft::handle_t& handle, batched_loglike(handle, arima_mem, d_y, + d_exog, batch_size, n_obs, order, @@ -481,12 +517,7 @@ void batched_loglike(raft::handle_t& handle, trans, host_loglike, method, - truncate, - fc_steps, - d_fc, - level, - d_lower, - d_upper); + truncate); ML::POP_RANGE(); } @@ -494,6 +525,7 @@ void batched_loglike(raft::handle_t& handle, void batched_loglike_grad(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -520,6 +552,7 @@ void batched_loglike_grad(raft::handle_t& handle, batched_loglike(handle, arima_mem, d_y, + d_exog, batch_size, n_obs, order, @@ -541,6 +574,7 @@ void batched_loglike_grad(raft::handle_t& handle, batched_loglike(handle, arima_mem, d_y, + d_exog, batch_size, n_obs, order, @@ -570,6 +604,7 @@ void batched_loglike_grad(raft::handle_t& handle, void information_criterion(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -582,7 +617,7 @@ void information_criterion(raft::handle_t& handle, /* Compute log-likelihood in d_ic */ batched_loglike( - handle, arima_mem, d_y, batch_size, n_obs, order, params, d_ic, false, false, MLE); + handle, arima_mem, d_y, d_exog, batch_size, n_obs, order, params, d_ic, false, false, MLE); /* Compute information criterion from log-likelihood and base term */ MLCommon::Metrics::Batched::information_criterion( @@ -607,8 +642,8 @@ void information_criterion(raft::handle_t& handle, template DI bool test_invparams(const double* params, int pq) { - double new_params[4]; - double tmp[4]; + double new_params[8]; + double tmp[8]; constexpr double coef = isAr ? 1 : -1; @@ -820,9 +855,77 @@ void _arma_least_squares(raft::handle_t& handle, */ void _start_params(raft::handle_t& handle, ARIMAParams& params, - const MLCommon::LinAlg::Batched::Matrix& bm_y, + MLCommon::LinAlg::Batched::Matrix& bm_y, + const MLCommon::LinAlg::Batched::Matrix& bm_exog, const ARIMAOrder& order) { + int batch_size = bm_exog.batches(); + cudaStream_t stream = bm_exog.stream(); + + // Estimate exog coefficients and subtract component to endog. + // Exog coefficients are estimated by fitting a linear regression with X=exog, y=endog + if (order.n_exog > 0) { + // In most cases, the system will be overdetermined and we can use gels + if (bm_exog.shape().first > static_cast(order.n_exog)) { + // Make a copy of the exogenous series for in-place gels + MLCommon::LinAlg::Batched::Matrix bm_exog_copy(bm_exog); + // Make a copy of the endogenous series for in-place gels + MLCommon::LinAlg::Batched::Matrix bm_y_copy(bm_y); + + // Least-squares solution of overdetermined system + rmm::device_uvector info(batch_size, stream); + b_gels(bm_exog_copy, bm_y_copy, info.data()); + + // Make a batched matrix around the exogenous coefficients + rmm::device_uvector beta_pointers(batch_size, stream); + MLCommon::LinAlg::Batched::Matrix bm_exog_coef(order.n_exog, + 1, + batch_size, + bm_exog.cublasHandle(), + beta_pointers.data(), + params.beta, + stream, + false); + + // Copy the solution of the system to the parameters array + b_2dcopy(bm_y_copy, bm_exog_coef, 0, 0, order.n_exog, 1); + + // Set parameters to zero when solving was not successful + auto counting = thrust::make_counting_iterator(0); + int* devInfoArray = info.data(); + double* d_exog_coef = bm_exog_coef.raw_data(); + const int& n_exog = order.n_exog; + thrust::for_each( + thrust::cuda::par.on(stream), counting, counting + batch_size, [=] __device__(int bid) { + if (devInfoArray[bid] > 0) { + for (int i = 0; i < n_exog; i++) { + d_exog_coef[bid * n_exog + i] = 0.0; + } + } + }); + + // Compute exogenous component and store the result in bm_y_copy + b_gemm(false, + false, + bm_exog.shape().first, + 1, + bm_exog.shape().second, + 1.0, + bm_exog, + bm_exog_coef, + 0.0, + bm_y_copy); + + // Subtract exogenous component to endogenous variable + b_aA_op_B(bm_y, bm_y_copy, bm_y, [] __device__(double a, double b) { return a - b; }); + } + // In other cases, we initialize to zero + else { + CUDA_CHECK( + cudaMemsetAsync(params.beta, 0, order.n_exog * batch_size * sizeof(double), stream)); + } + } + // Estimate an ARMA fit without seasonality if (order.p + order.q + order.k) _arma_least_squares(handle, @@ -853,6 +956,7 @@ void _start_params(raft::handle_t& handle, void estimate_x0(raft::handle_t& handle, ARIMAParams& params, const double* d_y, + const double* d_exog, int batch_size, int n_obs, const ARIMAOrder& order, @@ -863,6 +967,10 @@ void estimate_x0(raft::handle_t& handle, auto stream = handle_impl.get_stream(); auto cublas_handle = handle_impl.get_cublas_handle(); + /// TODO: solve exogenous coefficients with only valid rows instead of interpolation? + // Pros: better coefficients + // Cons: harder to test, a bit more complicated + // Least squares can't deal with missing values: create copy with naive // replacements for missing values const double* d_y_no_missing; @@ -883,8 +991,22 @@ void estimate_x0(raft::handle_t& handle, MLCommon::TimeSeries::prepare_data( bm_yd.raw_data(), d_y_no_missing, batch_size, n_obs, order.d, order.D, order.s, stream); + // Difference or copy exog + MLCommon::LinAlg::Batched::Matrix bm_exog_diff( + n_obs - order.d - order.s * order.D, order.n_exog, batch_size, cublas_handle, stream, false); + if (order.n_exog > 0) { + MLCommon::TimeSeries::prepare_data(bm_exog_diff.raw_data(), + d_exog, + order.n_exog * batch_size, + n_obs, + order.d, + order.D, + order.s, + stream); + } + // Do the computation of the initial parameters - _start_params(handle, params, bm_yd, order); + _start_params(handle, params, bm_yd, bm_exog_diff, order); ML::POP_RANGE(); } diff --git a/cpp/src/arima/batched_kalman.cu b/cpp/src/arima/batched_kalman.cu index 572f0abb09..1c189dde13 100644 --- a/cpp/src/arima/batched_kalman.cu +++ b/cpp/src/arima/batched_kalman.cu @@ -103,24 +103,26 @@ DI void numerical_stability(double* A) * Kalman loop kernel. Each thread computes kalman filter for a single series * and stores relevant matrices in registers. * - * @tparam r Dimension of the state vector - * @param[in] ys Batched time series - * @param[in] nobs Number of observation per series - * @param[in] T Batched transition matrix. (r x r) - * @param[in] Z Batched "design" vector (1 x r) - * @param[in] RQR Batched R*Q*R' (r x r) - * @param[in] P Batched P (r x r) - * @param[in] alpha Batched state vector (r x 1) - * @param[in] intercept Do we fit an intercept? - * @param[in] d_mu Batched intercept (1) - * @param[in] batch_size Batch size - * @param[out] d_pred Predictions (nobs) - * @param[out] d_loglike Log-likelihood (1) - * @param[in] n_diff d + s*D - * @param[in] fc_steps Number of steps to forecast - * @param[out] d_fc Array to store the forecast - * @param[in] conf_int Whether to compute confidence intervals - * @param[out] d_F_fc Batched variance of forecast errors (fc_steps) + * @tparam rd Dimension of the state vector + * @param[in] ys Batched time series + * @param[in] nobs Number of observation per series + * @param[in] T Batched transition matrix. (r x r) + * @param[in] Z Batched "design" vector (1 x r) + * @param[in] RQR Batched R*Q*R' (r x r) + * @param[in] P Batched P (r x r) + * @param[in] alpha Batched state vector (r x 1) + * @param[in] intercept Do we fit an intercept? + * @param[in] d_mu Batched intercept (1) + * @param[in] batch_size Batch size + * @param[in] d_obs_inter Observation intercept + * @param[in] d_obs_inter_fut Observation intercept for forecasts + * @param[out] d_pred Predictions (nobs) + * @param[out] d_loglike Log-likelihood (1) + * @param[in] n_diff d + s*D + * @param[in] fc_steps Number of steps to forecast + * @param[out] d_fc Array to store the forecast + * @param[in] conf_int Whether to compute confidence intervals + * @param[out] d_F_fc Batched variance of forecast errors (fc_steps) */ template __global__ void batched_kalman_loop_kernel(const double* ys, @@ -133,6 +135,8 @@ __global__ void batched_kalman_loop_kernel(const double* ys, bool intercept, const double* d_mu, int batch_size, + const double* d_obs_inter, + const double* d_obs_inter_fut, double* d_pred, double* d_loglike, int n_diff, @@ -181,11 +185,11 @@ __global__ void batched_kalman_loop_kernel(const double* ys, bool missing; { // 1. v = y - Z*alpha - double pred; + double pred = 0.0; + if (d_obs_inter != nullptr) { pred += d_obs_inter[bid * nobs + it]; } if (n_diff == 0) - pred = l_alpha[0]; + pred += l_alpha[0]; else { - pred = 0.0; for (int i = 0; i < rd; i++) { pred += l_alpha[i] * l_Z[i]; } @@ -286,15 +290,16 @@ __global__ void batched_kalman_loop_kernel(const double* ys, double* b_fc = fc_steps ? d_fc + bid * fc_steps : nullptr; double* b_F_fc = conf_int ? d_F_fc + bid * fc_steps : nullptr; for (int it = 0; it < fc_steps; it++) { + double pred = 0.0; + if (d_obs_inter_fut != nullptr) { pred += d_obs_inter_fut[bid * fc_steps + it]; } if (n_diff == 0) - b_fc[it] = l_alpha[0]; + pred += l_alpha[0]; else { - double pred = 0.0; for (int i = 0; i < rd; i++) { pred += l_alpha[i] * l_Z[i]; } - b_fc[it] = pred; } + b_fc[it] = pred; // alpha = T*alpha + c Mv_l(l_T, l_alpha, l_tmp); @@ -349,29 +354,31 @@ union KalmanLoopSharedMemory { /** * Kalman loop kernel. Each block computes kalman filter for a single series. * - * @tparam GemmPolicy Execution policy for GEMM - * @tparam GemvPolicy Execution policy for GEMV - * @tparam CovPolicy Execution policy for the covariance stability operation - * @param[in] d_ys Batched time series - * @param[in] batch_size Batch size - * @param[in] n_obs Number of observation per series - * @param[in] d_T Batched transition matrix. (r x r) - * @param[in] d_Z Batched "design" vector (1 x r) - * @param[in] d_RQR Batched R*Q*R' (r x r) - * @param[in] d_P Batched P (r x r) - * @param[in] d_alpha Batched state vector (r x 1) - * @param[in] d_m_tmp Batched temporary matrix (r x r) - * @param[in] d_TP Batched temporary matrix to store TP (r x r) - * @param[in] intercept Do we fit an intercept? - * @param[in] d_mu Batched intercept (1) - * @param[in] rd State vector dimension - * @param[out] d_pred Predictions (nobs) - * @param[out] d_loglike Log-likelihood (1) - * @param[in] n_diff d + s*D - * @param[in] fc_steps Number of steps to forecast - * @param[out] d_fc Array to store the forecast - * @param[in] conf_int Whether to compute confidence intervals - * @param[out] d_F_fc Batched variance of forecast errors (fc_steps) + * @tparam GemmPolicy Execution policy for GEMM + * @tparam GemvPolicy Execution policy for GEMV + * @tparam CovPolicy Execution policy for the covariance stability operation + * @param[in] d_ys Batched time series + * @param[in] batch_size Batch size + * @param[in] n_obs Number of observation per series + * @param[in] d_T Batched transition matrix. (r x r) + * @param[in] d_Z Batched "design" vector (1 x r) + * @param[in] d_RQR Batched R*Q*R' (r x r) + * @param[in] d_P Batched P (r x r) + * @param[in] d_alpha Batched state vector (r x 1) + * @param[in] d_m_tmp Batched temporary matrix (r x r) + * @param[in] d_TP Batched temporary matrix to store TP (r x r) + * @param[in] intercept Do we fit an intercept? + * @param[in] d_mu Batched intercept (1) + * @param[in] rd State vector dimension + * @param[in] d_obs_inter Observation intercept + * @param[in] d_obs_inter_fut Observation intercept for forecasts + * @param[out] d_pred Predictions (nobs) + * @param[out] d_loglike Log-likelihood (1) + * @param[in] n_diff d + s*D + * @param[in] fc_steps Number of steps to forecast + * @param[out] d_fc Array to store the forecast + * @param[in] conf_int Whether to compute confidence intervals + * @param[out] d_F_fc Batched variance of forecast errors (fc_steps) */ template __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys, @@ -387,6 +394,8 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys, bool intercept, const double* d_mu, int rd, + const double* d_obs_inter, + const double* d_obs_inter_fut, double* d_pred, double* d_loglike, int n_diff, @@ -420,22 +429,24 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys, double sum_logFs = 0.0; double ll_s2 = 0.0; int n_obs_ll = 0; - int it; + int it = 0; /* Skip missing observations at the start */ - { - double pred0; - if (n_diff == 0) { - pred0 = shared_alpha[0]; - } else { - pred0 = 0.0; - pred0 += MLCommon::LinAlg::_block_dot( - rd, shared_Z, shared_alpha, shared_mem.reduction_storage); - __syncthreads(); // necessary to reuse shared memory - } + if (d_obs_inter == nullptr) { + { + double pred0; + if (n_diff == 0) { + pred0 = shared_alpha[0]; + } else { + pred0 = 0.0; + pred0 += MLCommon::LinAlg::_block_dot( + rd, shared_Z, shared_alpha, shared_mem.reduction_storage); + __syncthreads(); // necessary to reuse shared memory + } - for (it = 0; it < n_obs && isnan(d_ys[bid * n_obs + it]); it++) { - if (threadIdx.x == 0) d_pred[bid * n_obs + it] = pred0; + for (; it < n_obs && isnan(d_ys[bid * n_obs + it]); it++) { + if (threadIdx.x == 0) d_pred[bid * n_obs + it] = pred0; + } } } @@ -444,13 +455,13 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys, double vt, _F; bool missing; { - // 1. pred = Z*alpha + // 1. pred = Z*alpha + obs_intercept // v = y - pred - double pred; + double pred = 0.0; + if (d_obs_inter != nullptr) { pred += d_obs_inter[bid * n_obs + it]; } if (n_diff == 0) { - pred = shared_alpha[0]; + pred += shared_alpha[0]; } else { - pred = 0.0; pred += MLCommon::LinAlg::_block_dot( rd, shared_Z, shared_alpha, shared_mem.reduction_storage); __syncthreads(); // necessary to reuse shared memory @@ -561,12 +572,13 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys, /* Forecast */ for (int it = 0; it < fc_steps; it++) { - // pred = Z * alpha - double pred; + // pred = Z * alpha + obs_intercept + double pred = 0.0; + if (d_obs_inter_fut != nullptr) { pred += d_obs_inter_fut[bid * fc_steps + it]; } if (n_diff == 0) { - pred = shared_alpha[0]; + pred += shared_alpha[0]; } else { - pred = MLCommon::LinAlg::_block_dot( + pred += MLCommon::LinAlg::_block_dot( rd, shared_Z, shared_alpha, shared_mem.reduction_storage); __syncthreads(); // necessary to reuse shared memory } @@ -647,24 +659,26 @@ __global__ void _batched_kalman_device_loop_large_kernel(const double* d_ys, /** * Kalman loop for large matrices (r > 8). * - * @param[in] arima_mem Pre-allocated temporary memory - * @param[in] d_ys Batched time series - * @param[in] nobs Number of observation per series - * @param[in] T Batched transition matrix. (r x r) - * @param[in] Z Batched "design" vector (1 x r) - * @param[in] RQR Batched R*Q*R' (r x r) - * @param[in] P Batched P (r x r) - * @param[in] alpha Batched state vector (r x 1) - * @param[in] intercept Do we fit an intercept? - * @param[in] d_mu Batched intercept (1) - * @param[in] rd Dimension of the state vector - * @param[out] d_pred Predictions (nobs) - * @param[out] d_loglike Log-likelihood (1) - * @param[in] n_diff d + s*D - * @param[in] fc_steps Number of steps to forecast - * @param[out] d_fc Array to store the forecast - * @param[in] conf_int Whether to compute confidence intervals - * @param[out] d_F_fc Batched variance of forecast errors (fc_steps) + * @param[in] arima_mem Pre-allocated temporary memory + * @param[in] d_ys Batched time series + * @param[in] nobs Number of observation per series + * @param[in] T Batched transition matrix. (r x r) + * @param[in] Z Batched "design" vector (1 x r) + * @param[in] RQR Batched R*Q*R' (r x r) + * @param[in] P Batched P (r x r) + * @param[in] alpha Batched state vector (r x 1) + * @param[in] intercept Do we fit an intercept? + * @param[in] d_mu Batched intercept (1) + * @param[in] rd Dimension of the state vector + * @param[in] d_obs_inter Observation intercept + * @param[in] d_obs_inter_fut Observation intercept for forecasts + * @param[out] d_pred Predictions (nobs) + * @param[out] d_loglike Log-likelihood (1) + * @param[in] n_diff d + s*D + * @param[in] fc_steps Number of steps to forecast + * @param[out] d_fc Array to store the forecast + * @param[in] conf_int Whether to compute confidence intervals + * @param[out] d_F_fc Batched variance of forecast errors (fc_steps) */ template void _batched_kalman_device_loop_large(const ARIMAMemory& arima_mem, @@ -678,6 +692,8 @@ void _batched_kalman_device_loop_large(const ARIMAMemory& arima_mem, bool intercept, const double* d_mu, int rd, + const double* d_obs_inter, + const double* d_obs_inter_fut, double* d_pred, double* d_loglike, int n_diff, @@ -723,6 +739,8 @@ void _batched_kalman_device_loop_large(const ARIMAMemory& arima_mem, intercept, d_mu, rd, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -745,6 +763,8 @@ void batched_kalman_loop(raft::handle_t& handle, bool intercept, const double* d_mu, const ARIMAOrder& order, + const double* d_obs_inter, + const double* d_obs_inter_fut, double* d_pred, double* d_loglike, int fc_steps = 0, @@ -772,6 +792,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -792,6 +814,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -812,6 +836,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -832,6 +858,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -852,6 +880,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -872,6 +902,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -892,6 +924,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -912,6 +946,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, batch_size, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -941,6 +977,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, rd, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -963,6 +1001,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, rd, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -987,6 +1027,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, rd, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -1009,6 +1051,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, rd, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -1032,6 +1076,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, rd, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -1054,6 +1100,8 @@ void batched_kalman_loop(raft::handle_t& handle, intercept, d_mu, rd, + d_obs_inter, + d_obs_inter_fut, d_pred, d_loglike, n_diff, @@ -1135,6 +1183,7 @@ void _lyapunov_wrapper(raft::handle_t& handle, void _batched_kalman_filter(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_ys, + const double* d_exog, int nobs, const ARIMAOrder& order, const MLCommon::LinAlg::Batched::Matrix& Zb, @@ -1145,8 +1194,10 @@ void _batched_kalman_filter(raft::handle_t& handle, const double* d_sigma2, bool intercept, const double* d_mu, + const double* d_beta, int fc_steps, double* d_fc, + const double* d_exog_fut, double level, double* d_lower, double* d_upper) @@ -1161,6 +1212,61 @@ void _batched_kalman_filter(raft::handle_t& handle, int rd = order.rd(); int r = order.r(); + // Compute observation intercept (exogenous component). + // The observation intercept is a linear combination of the values of + // exogenous variables for this observation. + rmm::device_uvector obs_intercept(0, stream); + rmm::device_uvector obs_intercept_fut(0, stream); + if (order.n_exog > 0) { + obs_intercept.resize(nobs * batch_size, stream); + + double alpha = 1.0; + double beta = 0.0; + CUBLAS_CHECK(raft::linalg::cublasgemmStridedBatched(cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_N, + nobs, + 1, + order.n_exog, + &alpha, + d_exog, + nobs, + nobs * order.n_exog, + d_beta, + order.n_exog, + order.n_exog, + &beta, + obs_intercept.data(), + nobs, + nobs, + batch_size, + stream)); + + if (fc_steps > 0) { + obs_intercept_fut.resize(fc_steps * batch_size, stream); + + CUBLAS_CHECK(raft::linalg::cublasgemmStridedBatched(cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_N, + fc_steps, + 1, + order.n_exog, + &alpha, + d_exog_fut, + fc_steps, + fc_steps * order.n_exog, + d_beta, + order.n_exog, + order.n_exog, + &beta, + obs_intercept_fut.data(), + fc_steps, + fc_steps, + batch_size, + stream)); + } + } + MLCommon::LinAlg::Batched::Matrix RQb( rd, 1, batch_size, cublasHandle, arima_mem.RQ_batches, arima_mem.RQ_dense, stream, true); double* d_RQ = RQb.raw_data(); @@ -1306,6 +1412,8 @@ void _batched_kalman_filter(raft::handle_t& handle, intercept, d_mu, order, + obs_intercept.data(), + obs_intercept_fut.data(), d_pred, d_loglike, fc_steps, @@ -1434,6 +1542,7 @@ void init_batched_kalman_matrices(raft::handle_t& handle, void batched_kalman_filter(raft::handle_t& handle, const ARIMAMemory& arima_mem, const double* d_ys, + const double* d_exog, int nobs, const ARIMAParams& params, const ARIMAOrder& order, @@ -1442,6 +1551,7 @@ void batched_kalman_filter(raft::handle_t& handle, double* d_pred, int fc_steps, double* d_fc, + const double* d_exog_fut, double level, double* d_lower, double* d_upper) @@ -1479,6 +1589,7 @@ void batched_kalman_filter(raft::handle_t& handle, _batched_kalman_filter(handle, arima_mem, d_ys, + d_exog, nobs, order, Zb, @@ -1489,8 +1600,10 @@ void batched_kalman_filter(raft::handle_t& handle, params.sigma2, static_cast(order.k), params.mu, + params.beta, fc_steps, d_fc, + d_exog_fut, level, d_lower, d_upper); @@ -1511,12 +1624,14 @@ void batched_jones_transform(raft::handle_t& handle, double* d_params = arima_mem.d_params; double* d_Tparams = arima_mem.d_Tparams; ARIMAParams params = {arima_mem.params_mu, + arima_mem.params_beta, arima_mem.params_ar, arima_mem.params_ma, arima_mem.params_sar, arima_mem.params_sma, arima_mem.params_sigma2}; - ARIMAParams Tparams = {arima_mem.Tparams_mu, + ARIMAParams Tparams = {params.mu, + params.beta, arima_mem.Tparams_ar, arima_mem.Tparams_ma, arima_mem.Tparams_sar, @@ -1528,7 +1643,6 @@ void batched_jones_transform(raft::handle_t& handle, params.unpack(order, batch_size, d_params, stream); MLCommon::TimeSeries::batched_jones_transform(order, batch_size, isInv, params, Tparams, stream); - Tparams.mu = params.mu; Tparams.pack(order, batch_size, d_Tparams, stream); diff --git a/cpp/src_prims/linalg/batched/matrix.cuh b/cpp/src_prims/linalg/batched/matrix.cuh index 446fcc4626..4d2aaf04e6 100644 --- a/cpp/src_prims/linalg/batched/matrix.cuh +++ b/cpp/src_prims/linalg/batched/matrix.cuh @@ -329,14 +329,14 @@ class Matrix { } //! Visualize the first matrix. - void print(std::string name) const + void print(std::string name, size_t ib = 0) const { - std::size_t len = m_shape.first * m_shape.second * m_batch_size; + std::size_t len = m_shape.first * m_shape.second; std::vector A(len); - raft::update_host(A.data(), raw_data(), len, m_stream); + raft::update_host(A.data(), raw_data() + ib * len, len, m_stream); std::cout << name << "=\n"; - for (int i = 0; i < m_shape.first; i++) { - for (int j = 0; j < m_shape.second; j++) { + for (size_t i = 0; i < m_shape.first; i++) { + for (size_t j = 0; j < m_shape.second; j++) { // column major std::cout << std::setprecision(10) << A[j * m_shape.first + i] << ","; } @@ -640,11 +640,13 @@ Matrix b_gemm(const Matrix& A, const Matrix& B, bool aT = false, bool b * - cuBLAS only supports overdetermined systems. * - This function copies A to avoid modifying the original one. * - * @param[in] A Batched matrix A (must have more rows than columns) - * @param[inout] C Batched matrix C (the number of rows must match A) + * @param[in] A Batched matrix A (must have more rows than columns) + * @param[inout] C Batched matrix C (the number of rows must match A) + * @param[out] infoArr (optional) Success indicator for each problem. + * See devInfoArray in cuBLAS documentation. */ template -void b_gels(const Matrix& A, Matrix& C) +void b_gels(const Matrix& A, Matrix& C, int* devInfoArray = nullptr) { ASSERT(A.batches() == C.batches(), "A and C must have the same number of batches"); auto m = A.shape().first; @@ -666,7 +668,7 @@ void b_gels(const Matrix& A, Matrix& C) C.data(), m, &info, - nullptr, + devInfoArray, A.batches(), A.stream())); } @@ -698,25 +700,40 @@ Matrix b_op_A(const Matrix& A, F unary_op) * * @param[in] A Batched matrix A * @param[in] B Batched matrix B + * @param[out] C Batched matrix C, result of A binary_op B * @param[in] binary_op The binary operation used on elements of A and B - * @return A batched matrix, the result of A binary_op B */ template -Matrix b_aA_op_B(const Matrix& A, const Matrix& B, F binary_op) +void b_aA_op_B(const Matrix& A, const Matrix& B, Matrix& C, F binary_op) { ASSERT(A.shape().first == B.shape().first && A.shape().second == B.shape().second, "ERROR: Matrices must be same size"); ASSERT(A.batches() == B.batches(), "A & B must have same number of batches"); - auto batch_size = A.batches(); - auto m = A.shape().first; - auto n = A.shape().second; + raft::linalg::binaryOp(C.raw_data(), + A.raw_data(), + B.raw_data(), + A.shape().first * A.shape().second * A.batches(), + binary_op, + A.stream()); +} - Matrix C(m, n, batch_size, A.cublasHandle(), A.stream()); +/** + * @brief A utility method to implement pointwise operations between elements + * of two batched matrices. + * + * @param[in] A Batched matrix A + * @param[in] B Batched matrix B + * @param[in] binary_op The binary operation used on elements of A and B + * @return A batched matrix, the result of A binary_op B + */ +template +Matrix b_aA_op_B(const Matrix& A, const Matrix& B, F binary_op) +{ + Matrix C(A.shape().first, A.shape().second, A.batches(), A.cublasHandle(), A.stream()); - raft::linalg::binaryOp( - C.raw_data(), A.raw_data(), B.raw_data(), m * n * batch_size, binary_op, A.stream()); + b_aA_op_B(A, B, C, binary_op); return C; } diff --git a/cpp/src_prims/timeSeries/arima_helpers.cuh b/cpp/src_prims/timeSeries/arima_helpers.cuh index 83f5ffba16..85a3c48973 100644 --- a/cpp/src_prims/timeSeries/arima_helpers.cuh +++ b/cpp/src_prims/timeSeries/arima_helpers.cuh @@ -27,8 +27,8 @@ #include #include "jones_transform.cuh" -namespace MLCommon { -namespace TimeSeries { +// Private helper functions and kernels in the anonymous namespace +namespace { /** * Auxiliary function of reduced_polynomial. Computes a coefficient of an (S)AR @@ -52,6 +52,128 @@ HDI DataT _param_to_poly(const DataT* param, int lags, int idx) return 1.0; } +/** + * @brief Helper function that will read in src0 if the given index is + * negative, src1 otherwise. + * @note This is useful when one array is the logical continuation of + * another and the index is expressed relatively to the second array. + * + * @param[in] src0 Data comes from here if the index is negative + * @param[in] size0 Size of src0 + * @param[in] src1 Data comes from here if the index is positive + * @param[in] idx Index, relative to the start of the second array src1 + * @return Data read from src0 or src1 according to the index + */ +template +DI DataT _select_read(const DataT* src0, int size0, const DataT* src1, int idx) +{ + return idx < 0 ? src0[size0 + idx] : src1[idx]; +} + +/** + * @brief Prepare future data with a simple or seasonal difference + * + * @param[in] in_past Input (past). Shape (n_past, batch_size) (device) + * @param[in] in_fut Input (future). Shape (n_fut, batch_size) (device) + * @param[out] out Output. Shape (n_fut, batch_size) (device) + * @param[in] n_past Number of past observations per series + * @param[in] n_fut Number of future observations per series + * @param[in] period Differencing period (1 or s) + * @param[in] stream CUDA stream + */ +template +__global__ void _future_diff_kernel( + const T* in_past, const T* in_fut, T* out, int n_past, int n_fut, int period = 1) +{ + const T* b_in_past = in_past + n_past * blockIdx.x; + const T* b_in_fut = in_fut + n_fut * blockIdx.x; + T* b_out = out + n_fut * blockIdx.x; + + for (int i = threadIdx.x; i < n_fut; i += blockDim.x) { + b_out[i] = b_in_fut[i] - _select_read(b_in_past, n_past, b_in_fut, i - period); + } +} + +/** + * @brief Prepare future data with two simple and/or seasonal differences + * + * @param[in] in_past Input (past). Shape (n_past, batch_size) (device) + * @param[in] in_fut Input (future). Shape (n_fut, batch_size) (device) + * @param[out] out Output. Shape (n_fut, batch_size) (device) + * @param[in] n_past Number of past observations per series + * @param[in] n_fut Number of future observations per series + * @param[in] period1 First differencing period (1 or s) + * @param[in] period2 Second differencing period (1 or s) + * @param[in] stream CUDA stream + */ +template +__global__ void _future_second_diff_kernel(const T* in_past, + const T* in_fut, + T* out, + int n_past, + int n_fut, + int period1 = 1, + int period2 = 1) +{ + const T* b_in_past = in_past + n_past * blockIdx.x; + const T* b_in_fut = in_fut + n_fut * blockIdx.x; + T* b_out = out + n_fut * blockIdx.x; + + for (int i = threadIdx.x; i < n_fut; i += blockDim.x) { + b_out[i] = b_in_fut[i] - _select_read(b_in_past, n_past, b_in_fut, i - period1) - + _select_read(b_in_past, n_past, b_in_fut, i - period2) + + _select_read(b_in_past, n_past, b_in_fut, i - period1 - period2); + } +} + +/** + * @brief Kernel to undifference the data with up to two levels of simple + * and/or seasonal differencing. + * @note One thread per series. + * + * @tparam double_diff true for two differences, false for one + * @tparam DataT Data type + * @param[inout] d_fc Forecasts, modified in-place + * @param[in] d_in Past observations + * @param[in] num_steps Number of forecast steps + * @param[in] batch_size Batch size + * @param[in] in_ld Leading dimension of d_in + * @param[in] n_in Number of past observations + * @param[in] s0 1st differencing period + * @param[in] s1 2nd differencing period if relevant + */ +template +__global__ void _undiff_kernel(DataT* d_fc, + const DataT* d_in, + int num_steps, + int batch_size, + int in_ld, + int n_in, + int s0, + int s1 = 0) +{ + int bid = blockIdx.x * blockDim.x + threadIdx.x; + if (bid < batch_size) { + DataT* b_fc = d_fc + bid * num_steps; + const DataT* b_in = d_in + bid * in_ld; + for (int i = 0; i < num_steps; i++) { + if (!double_diff) { // One simple or seasonal difference + b_fc[i] += _select_read(b_in, n_in, b_fc, i - s0); + } else { // Two differences (simple, seasonal or both) + DataT fc_acc = -_select_read(b_in, n_in, b_fc, i - s0 - s1); + fc_acc += _select_read(b_in, n_in, b_fc, i - s0); + fc_acc += _select_read(b_in, n_in, b_fc, i - s1); + b_fc[i] += fc_acc; + } + } + } +} + +} // namespace + +namespace MLCommon { +namespace TimeSeries { + /** * Helper function to compute the reduced AR or MA polynomial based on the * AR and SAR or MA and SMA parameters @@ -80,7 +202,6 @@ HDI DataT reduced_polynomial( /** * @brief Prepare data by differencing if needed (simple and/or seasonal) - * and removing a trend if needed * * @note: It is assumed that d + D <= 2. This is enforced on the Python side * @@ -128,51 +249,68 @@ void prepare_data(DataT* d_out, } /** - * @brief Helper function that will read in src0 if the given index is - * negative, src1 otherwise. - * @note This is useful when one array is the logical continuation of - * another and the index is expressed relatively to the second array. + * @brief Prepare future data by differencing if needed (simple and/or seasonal) + * + * This is a variant of prepare_data that produces an output of the same dimension + * as the input, using an other array of past data for the observations at the start + * + * @note: It is assumed that d + D <= 2. This is enforced on the Python side + * + * @param[out] d_out Output. Shape (n_fut, batch_size) (device) + * @param[in] d_in_past Input (past). Shape (n_past, batch_size) (device) + * @param[in] d_in_fut Input (future). Shape (n_fut, batch_size) (device) + * @param[in] batch_size Number of series per batch + * @param[in] n_past Number of past observations per series + * @param[in] n_fut Number of future observations per series + * @param[in] d Order of simple differences (0, 1 or 2) + * @param[in] D Order of seasonal differences (0, 1 or 2) + * @param[in] s Seasonal period if D > 0 + * @param[in] stream CUDA stream */ template -DI DataT _select_read(const DataT* src0, int size0, const DataT* src1, int idx) -{ - return idx < 0 ? src0[size0 + idx] : src1[idx]; -} - -/** - * @brief Kernel to undifference the data with up to two levels of simple - * and/or seasonal differencing. - * @note One thread per series. - */ -template -__global__ void _undiff_kernel(DataT* d_fc, - const DataT* d_in, - int num_steps, - int batch_size, - int in_ld, - int n_in, - int s0, - int s1 = 0) +void prepare_future_data(DataT* d_out, + const DataT* d_in_past, + const DataT* d_in_fut, + int batch_size, + int n_past, + int n_fut, + int d, + int D, + int s, + cudaStream_t stream) { - int bid = blockIdx.x * blockDim.x + threadIdx.x; - if (bid < batch_size) { - DataT* b_fc = d_fc + bid * num_steps; - const DataT* b_in = d_in + bid * in_ld; - for (int i = 0; i < num_steps; i++) { - if (!double_diff) { // One simple or seasonal difference - b_fc[i] += _select_read(b_in, n_in, b_fc, i - s0); - } else { // Two differences (simple, seasonal or both) - DataT fc_acc = -_select_read(b_in, n_in, b_fc, i - s0 - s1); - fc_acc += _select_read(b_in, n_in, b_fc, i - s0); - fc_acc += _select_read(b_in, n_in, b_fc, i - s1); - b_fc[i] += fc_acc; - } - } + // Only one difference (simple or seasonal) + if (d + D == 1) { + int period = d ? 1 : s; + int tpb = n_fut > 128 ? 64 : 32; // quick heuristics + _future_diff_kernel<<>>( + d_in_past, d_in_fut, d_out, n_past, n_fut, period); + CUDA_CHECK(cudaPeekAtLastError()); + } + // Two differences (simple or seasonal or both) + else if (d + D == 2) { + int period1 = d ? 1 : s; + int period2 = d == 2 ? 1 : s; + int tpb = n_fut > 128 ? 64 : 32; + _future_second_diff_kernel<<>>( + d_in_past, d_in_fut, d_out, n_past, n_fut, period1, period2); + CUDA_CHECK(cudaPeekAtLastError()); } + // If no difference and the pointers are different, copy in to out + else if (d + D == 0 && d_in_fut != d_out) { + raft::copy(d_out, d_in_fut, n_fut * batch_size, stream); + } + // Other cases: no difference and the pointers are the same, nothing to do } /** - * @brief Finalizes a forecast by undifferencing + * @brief Finalizes a forecast by undifferencing. + * + * This is used when doing "simple differencing" for integrated models (d > 0 or D > 0), i.e the + * series are differenced prior to running the Kalman filter. Forecasts output by the Kalman filter + * are then for the differenced series and we need to couple this with past observations to compute + * forecasts for the non-differenced series. This is not needed when differencing is handled by the + * Kalman filter. * * @note: It is assumed that d + D <= 2. This is enforced on the Python side * diff --git a/notebooks/arima_demo.ipynb b/notebooks/arima_demo.ipynb index d230ee955d..6321cba4b6 100644 --- a/notebooks/arima_demo.ipynb +++ b/notebooks/arima_demo.ipynb @@ -427,6 +427,71 @@ "source": [ "Note that the model can't form predictions at the start where we padded with missing values. The first in-sample predictions will be equal to a constant value (0 in the absence of intercept)." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exogenous variables\n", + "\n", + "ARIMA also supports exogenous variables. As this model works with a batch, there are some limitations: each column of `endog` corresponds to a fixed number of columns of `exog`. Exogenous variables can't be shared by different series (they have to be duplicated if necessary), and all series must have the same number of exogenous variables. The shape of `exog` must be `(n_obs, batch_size * n_exog)`, and columns are grouped by corresponding batch id. For predictions and forecasts, values of the exogenous variables for future steps must be provided in an array of shape `(nsteps, batch_size * n_exog)`.\n", + "\n", + "Note that endogenous variables might contain missing observations but exogenous variables cannot.\n", + "\n", + "To illustrate this, we will again create a fake dataset from the one used above, adding some procedural variables." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nb = 4\n", + "\n", + "# Generate exogenous variables and coefficients\n", + "get_sine = lambda n, period: \\\n", + " np.sin(np.r_[:n] * 2 * np.pi / period + np.random.uniform(0, period))\n", + "np_exog = np.column_stack([get_sine(319, T)\n", + " for T in np.random.uniform(20, 100, 2 * nb)])\n", + "np_exog_coef = np.random.uniform(20, 200, 2 * nb)\n", + "\n", + "# Create dataframes for the past and future values\n", + "df_exog = cudf.DataFrame(np_exog[:279])\n", + "df_exog_fut = cudf.DataFrame(np_exog[279:])\n", + "\n", + "# Add linear combination of the exogenous variables to the endogenous\n", + "df_guests_exog = df_guests.copy()\n", + "for ib in range(nb):\n", + " df_guests_exog[df_guests_exog.columns[ib]] += \\\n", + " np.matmul(np_exog[:279, ib*2:(ib+1)*2], np_exog_coef[ib*2:(ib+1)*2])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create and fit an ARIMA(1,0,1)(1,1,1)12 (c) model with exogenous variables\n", + "model_guests_exog = ARIMA(endog=df_guests_exog, exog=df_exog,\n", + " order=(1,0,1), seasonal_order=(1,1,1,12),\n", + " fit_intercept=True)\n", + "model_guests_exog.fit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Forecast\n", + "fc_guests_exog = model_guests_exog.forecast(40, exog=df_exog_fut)\n", + "\n", + "# Visualize after the time step 100\n", + "visualize(df_guests_exog[100:], fc_guests_exog)" + ] } ], "metadata": { @@ -446,7 +511,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.12" }, "mimetype": "text/x-python", "name": "python", diff --git a/python/cuml/datasets/arima.pyx b/python/cuml/datasets/arima.pyx index 8aff020038..d938b1ae3f 100644 --- a/python/cuml/datasets/arima.pyx +++ b/python/cuml/datasets/arima.pyx @@ -112,6 +112,7 @@ def make_arima(batch_size=1000, n_obs=100, order=(1, 1, 1), cpp_order.p, cpp_order.d, cpp_order.q = order cpp_order.P, cpp_order.D, cpp_order.Q, cpp_order.s = seasonal_order cpp_order.k = intercept + cpp_order.n_exog = 0 # Set the default output type to "cupy". This will be ignored if the user # has set `cuml.global_settings.output_type`. Only necessary for array diff --git a/python/cuml/test/test_arima.py b/python/cuml/test/test_arima.py index c6246cf87b..57f0f2a365 100644 --- a/python/cuml/test/test_arima.py +++ b/python/cuml/test/test_arima.py @@ -55,14 +55,18 @@ class ARIMAData: """Contains a dataset name and associated metadata """ + def __init__(self, batch_size, n_obs, n_test, dataset, - tolerance_integration): + tolerance_integration, n_exog=0, dataset_exog=None): self.batch_size = batch_size self.n_obs = n_obs self.n_test = n_test self.dataset = dataset self.tolerance_integration = tolerance_integration + self.n_exog = n_exog + self.dataset_exog = dataset_exog + self.n_train = n_obs - n_test @@ -111,6 +115,17 @@ def __init__(self, batch_size, n_obs, n_test, dataset, tolerance_integration=0.05 ) +# ARIMA(0,1,1) with intercept (exogenous variables) +test_011c_exog = ARIMAData( + batch_size=16, + n_obs=28, + n_test=2, + dataset="endog_deaths_by_region_exog", + tolerance_integration=0.05, + n_exog=2, + dataset_exog="exog_deaths_by_region_exog" +) + # ARIMA(1,2,1) with intercept test_121c = ARIMAData( batch_size=2, @@ -165,6 +180,18 @@ def __init__(self, batch_size, n_obs, n_test, dataset, tolerance_integration=0.01 ) +# ARIMA(1,1,1)(2,0,0)_4 with intercept +# (missing observations and exogenous variables) +test_111_200_4c_missing_exog = ARIMAData( + batch_size=14, + n_obs=123, + n_test=10, + dataset="endog_hourly_earnings_by_industry_missing_exog", + tolerance_integration=0.01, + n_exog=2, + dataset_exog="exog_hourly_earnings_by_industry_missing_exog", +) + # ARIMA(1,1,2)(0,1,2)_4 test_112_012_4 = ARIMAData( batch_size=2, @@ -192,6 +219,17 @@ def __init__(self, batch_size, n_obs, n_test, dataset, tolerance_integration=0.03 ) +# ARIMA(1,1,1)(1,1,1)_12 (missing obs, exogenous variables, intercept) +test_111_111_12c_missing_exog = ARIMAData( + batch_size=12, + n_obs=279, + n_test=20, + dataset="endog_guest_nights_by_region_missing_exog", + tolerance_integration=0.001, + n_exog=2, + dataset_exog="exog_guest_nights_by_region_missing_exog", +) + # Dictionary matching a test case to a tuple of model parameters # (a test case could be used with different models) # (p, d, q, P, D, Q, s, k) -> ARIMAData @@ -201,15 +239,18 @@ def __init__(self, batch_size, n_obs, n_test, dataset, ((0, 1, 0, 0, 0, 0, 0, 1), test_010c), ((1, 1, 0, 0, 0, 0, 0, 0), test_110), ((0, 1, 1, 0, 0, 0, 0, 1), test_011c), + ((0, 1, 1, 0, 0, 0, 0, 1), test_011c_exog), ((1, 2, 1, 0, 0, 0, 0, 1), test_121c), ((1, 1, 1, 0, 0, 0, 0, 1), test_111c_missing), ((1, 0, 1, 1, 1, 1, 4, 0), test_101_111_4), ((5, 1, 0, 0, 0, 0, 0, 0), test_510), ((1, 1, 1, 2, 0, 0, 4, 1), test_111_200_4c), ((1, 1, 1, 2, 0, 0, 4, 1), test_111_200_4c_missing), + ((1, 1, 1, 2, 0, 0, 4, 1), test_111_200_4c_missing_exog), ((1, 1, 2, 0, 1, 2, 4, 0), test_112_012_4), stress_param((1, 1, 1, 1, 1, 1, 12, 0), test_111_111_12), stress_param((1, 1, 1, 1, 1, 1, 12, 0), test_111_111_12_missing), + stress_param((1, 0, 1, 1, 1, 1, 12, 1), test_111_111_12c_missing_exog), ] # Dictionary for lazy-loading of datasets @@ -243,7 +284,19 @@ def get_dataset(data, dtype): shuffle=False) y_train_cudf = cudf.from_pandas(y_train).fillna(np.nan) y_test_cudf = cudf.from_pandas(y_test) - lazy_data[key] = (y_train, y_train_cudf, y_test, y_test_cudf) + if data.dataset_exog is not None: + exog = pd.read_csv( + os.path.join(data_path, "{}.csv".format(data.dataset_exog)), + usecols=range(1, data.n_exog * data.batch_size + 1), + dtype=dtype) + exog_past, exog_fut = train_test_split(exog, test_size=data.n_test, + shuffle=False) + exog_past_cudf = cudf.from_pandas(exog_past).fillna(np.nan) + exog_fut_cudf = cudf.from_pandas(exog_fut) + else: + exog_past, exog_past_cudf, exog_fut, exog_fut_cudf = [None]*4 + lazy_data[key] = (y_train, y_train_cudf, y_test, y_test_cudf, + exog_past, exog_past_cudf, exog_fut, exog_fut_cudf) return lazy_data[key] @@ -251,14 +304,20 @@ def get_ref_fit(data, order, seasonal_order, intercept, dtype): """Compute a reference fit of a dataset with the given parameters and dtype or return a previously computed fit """ - y_train, *_ = get_dataset(data, dtype) + y_train, _, _, _, exog_past, *_ = get_dataset(data, dtype) key = order + seasonal_order + \ (intercept, data.dataset, np.dtype(dtype).name) + batch_size = y_train.shape[1] if key not in lazy_ref_fit: - ref_model = [sm.tsa.SARIMAX(y_train[col], order=order, - seasonal_order=seasonal_order, - trend='c' if intercept else 'n') - for col in y_train.columns] + ref_model = [ + sm.tsa.SARIMAX(endog=y_train[y_train.columns[i]], + exog=exog_past[exog_past.columns[ + data.n_exog*i:data.n_exog*(i+1)]] + if exog_past is not None else None, + order=order, + seasonal_order=seasonal_order, + trend='c' if intercept else 'n') + for i in range(batch_size)] with warnings.catch_warnings(): warnings.filterwarnings("ignore") lazy_ref_fit[key] = [model.fit(disp=0) for model in ref_model] @@ -321,13 +380,15 @@ def test_integration(key, data, dtype): order, seasonal_order, intercept = extract_order(key) s = max(1, seasonal_order[3]) - y_train, y_train_cudf, y_test, _ = get_dataset(data, dtype) + y_train, y_train_cudf, y_test, _, _, exog_past_cudf, exog_fut, \ + exog_fut_cudf = get_dataset(data, dtype) # Get fit reference model ref_fits = get_ref_fit(data, order, seasonal_order, intercept, dtype) # Create and fit cuML model - cuml_model = arima.ARIMA(y_train_cudf, + cuml_model = arima.ARIMA(endog=y_train_cudf, + exog=exog_past_cudf, order=order, seasonal_order=seasonal_order, fit_intercept=intercept, @@ -335,11 +396,14 @@ def test_integration(key, data, dtype): cuml_model.fit() # Predict - y_fc_cuml = cuml_model.forecast(data.n_test) + y_fc_cuml = cuml_model.forecast(data.n_test, exog=exog_fut) y_fc_ref = np.zeros((data.n_test, data.batch_size)) for i in range(data.batch_size): y_fc_ref[:, i] = ref_fits[i].get_prediction( - data.n_train, data.n_obs - 1).predicted_mean + data.n_train, data.n_obs - 1, + exog=None if data.n_exog == 0 else + exog_fut[exog_fut.columns[data.n_exog*i:data.n_exog*(i+1)]] + ).predicted_mean # Compare results: MASE must be better or within the tolerance margin mase_ref = mase(y_train, y_test, y_fc_ref, s) @@ -373,13 +437,15 @@ def _predict_common(key, data, dtype, start, end, num_steps=None, level=None, """ order, seasonal_order, intercept = extract_order(key) - _, y_train_cudf, *_ = get_dataset(data, dtype) + _, y_train_cudf, _, _, _, exog_cudf, exog_fut, exog_fut_cudf \ + = get_dataset(data, dtype) # Get fit reference model ref_fits = get_ref_fit(data, order, seasonal_order, intercept, dtype) # Create cuML model - cuml_model = arima.ARIMA(y_train_cudf, + cuml_model = arima.ARIMA(endog=y_train_cudf, + exog=exog_cudf, order=order, seasonal_order=seasonal_order, fit_intercept=intercept, @@ -395,26 +461,38 @@ def _predict_common(key, data, dtype, start, end, num_steps=None, level=None, ref_preds = np.zeros((end - start, data.batch_size)) for i in range(data.batch_size): ref_preds[:, i] = ref_fits[i].get_prediction( - start, end - 1).predicted_mean + start, end - 1, + exog=( + None if data.n_exog == 0 or end <= data.n_train else + exog_fut[exog_fut.columns[data.n_exog*i:data.n_exog*(i+1)]]) + ).predicted_mean if level is not None: ref_lower = np.zeros((end - start, data.batch_size)) ref_upper = np.zeros((end - start, data.batch_size)) for i in range(data.batch_size): - temp_pred = ref_fits[i].get_forecast(num_steps) + temp_pred = ref_fits[i].get_forecast( + num_steps, + exog=( + None if data.n_exog == 0 else exog_fut[ + exog_fut.columns[data.n_exog*i:data.n_exog*(i+1)]]) + ) ci = temp_pred.summary_frame(alpha=1-level) ref_lower[:, i] = ci["mean_ci_lower"].to_numpy() ref_upper[:, i] = ci["mean_ci_upper"].to_numpy() # cuML if num_steps is None: - cuml_pred = cuml_model.predict(start, end) + cuml_pred = cuml_model.predict( + start, end, + exog=None if data.n_exog == 0 or end <= data.n_train + else exog_fut_cudf) elif level is not None: cuml_pred, cuml_lower, cuml_upper = \ - cuml_model.forecast(num_steps, level) + cuml_model.forecast(num_steps, level, exog=exog_fut_cudf) else: - cuml_pred = cuml_model.forecast(num_steps) + cuml_pred = cuml_model.forecast(num_steps, exog=exog_fut_cudf) # Compare results - np.testing.assert_allclose(cuml_pred, ref_preds, rtol=0.001, atol=0.01) + np.testing.assert_allclose(cuml_pred, ref_preds, rtol=0.002, atol=0.01) if level is not None: np.testing.assert_allclose( cuml_lower, ref_lower, rtol=0.005, atol=0.01) @@ -475,13 +553,14 @@ def test_loglikelihood(key, data, dtype, simple_differencing): """ order, seasonal_order, intercept = extract_order(key) - _, y_train_cudf, *_ = get_dataset(data, dtype) + _, y_train_cudf, _, _, _, exog_past_cudf, *_ = get_dataset(data, dtype) # Get fit reference model ref_fits = get_ref_fit(data, order, seasonal_order, intercept, dtype) # Create cuML model - cuml_model = arima.ARIMA(y_train_cudf, + cuml_model = arima.ARIMA(endog=y_train_cudf, + exog=exog_past_cudf, order=order, seasonal_order=seasonal_order, fit_intercept=intercept, @@ -511,17 +590,19 @@ def test_gradient(key, data, dtype): order, seasonal_order, intercept = extract_order(key) p, _, q = order P, _, Q, _ = seasonal_order - N = p + P + q + Q + intercept + 1 h = 1e-8 - _, y_train_cudf, *_ = get_dataset(data, dtype) + _, y_train_cudf, _, _, _, exog_past_cudf, *_ = get_dataset(data, dtype) # Create cuML model - cuml_model = arima.ARIMA(y_train_cudf, + cuml_model = arima.ARIMA(endog=y_train_cudf, + exog=exog_past_cudf, order=order, seasonal_order=seasonal_order, fit_intercept=intercept) + N = cuml_model.complexity + # Get an estimate of the parameters and pack them into a vector cuml_model._estimate_x0() x = cuml_model.pack() @@ -533,7 +614,10 @@ def test_gradient(key, data, dtype): scipy_grad = np.zeros(N * data.batch_size) for i in range(data.batch_size): # Create a model with only the current series - model_i = arima.ARIMA(y_train_cudf[y_train_cudf.columns[i]], + model_i = arima.ARIMA(endog=y_train_cudf[y_train_cudf.columns[i]], + exog=None if exog_past_cudf is None + else exog_past_cudf[exog_past_cudf.columns[ + data.n_exog*i:data.n_exog*(i+1)]], order=order, seasonal_order=seasonal_order, fit_intercept=intercept) @@ -555,20 +639,30 @@ def test_start_params(key, data, dtype): """ order, seasonal_order, intercept = extract_order(key) - y_train, y_train_cudf, *_ = get_dataset(data, dtype) + y_train, y_train_cudf, _, _, exog_past, exog_past_cudf, *_ \ + = get_dataset(data, dtype) # fillna for reference to match cuML initial estimation strategy y_train_nona = fill_interpolation(y_train) + # Convert to numpy to avoid misaligned indices + if exog_past is not None: + exog_past_np = exog_past.to_numpy() + # Create models - cuml_model = arima.ARIMA(y_train_cudf, + cuml_model = arima.ARIMA(endog=y_train_cudf, + exog=exog_past_cudf, order=order, seasonal_order=seasonal_order, fit_intercept=intercept) - ref_model = [sm.tsa.SARIMAX(y_train_nona[col], order=order, + ref_model = [sm.tsa.SARIMAX(endog=y_train_nona[y_train_nona.columns[i]], + exog=exog_past_np[ + :, i*data.n_exog:(i+1)*data.n_exog] + if data.n_exog else None, + order=order, seasonal_order=seasonal_order, trend='c' if intercept else 'n') - for col in y_train_nona.columns] + for i in range(data.batch_size)] # Estimate reference starting parameters N = cuml_model.complexity diff --git a/python/cuml/test/ts_datasets/README.md b/python/cuml/test/ts_datasets/README.md index 6c1bbe3dba..03f17ff8b1 100644 --- a/python/cuml/test/ts_datasets/README.md +++ b/python/cuml/test/ts_datasets/README.md @@ -19,7 +19,15 @@ This folder contains various datasets to test our time series analysis. Using da - `police_recorded_crime.csv`: Recorded crimes (units) per year, 1878-2014. - `population_estimate.csv`: Population estimates (thousands) per year, 1875-2011. -The following files are derived from the Stats NZ dataset by removing observations (to test support for missing observations): +The following files are derived from the Stats NZ dataset by removing observations (to test support for missing observations) and/or adding procedural exogenous variables: - `guest_nights_by_region_missing.csv` - `hourly_earnings_by_industry_missing.csv` -- `population_estimate_missing.csv` \ No newline at end of file +- `population_estimate_missing.csv` +- `endog_deaths_by_region_exog.csv` +- `endog_guest_nights_by_region_missing_exog.csv` +- `endog_hourly_earnings_by_industry_missing_exog.csv` + +The following files represent procedural exogenous variables linked to the series above (normalized): +- `exog_deaths_by_region_exog.csv` +- `exog_guest_nights_by_region_missing_exog.csv` +- `exog_hourly_earnings_by_industry_missing_exog.csv` diff --git a/python/cuml/test/ts_datasets/endog_deaths_by_region_exog.csv b/python/cuml/test/ts_datasets/endog_deaths_by_region_exog.csv new file mode 100644 index 0000000000..bc8159e281 --- /dev/null +++ b/python/cuml/test/ts_datasets/endog_deaths_by_region_exog.csv @@ -0,0 +1,29 @@ +,Northland Region,Auckland Region,Waikato Region,Bay of Plenty Region,Gisborne Region,Hawke's Bay Region,Taranaki Region,Manawatu-Wanganui Region,Wellington Region,Tasman Region,Nelson Region,Marlborough Region,West Coast Region,Canterbury Region,Otago Region,Southland Region +0,996.0,6768.0,2256.0,1647.0,369.0,1206.0,888.0,1950.0,2940.0,237.0,360.0,270.0,318.0,3756.0,1581.0,840.0 +1,1038.9504083744005,6952.438043167127,2485.113767748108,1779.3883525138863,431.89560211209425,1304.2642215271565,1036.5471247093506,2036.1456081473514,2992.806902426471,322.62946335095785,406.0179425670709,379.3377726268739,390.8048040628316,4075.498902651059,1675.393568860146,943.6548564339623 +2,1093.1087848206537,7152.236292341489,2381.0794761244383,1822.0386957877217,517.0666181419041,1342.2298786056208,1015.5803915429468,2024.0105974570256,3040.2493505739303,352.8670950715875,405.17438599581175,362.5140453524556,312.0184771435616,3876.6108087872935,1540.0350027692327,766.4226405747478 +3,1154.7146732241865,7094.990758465727,2481.7197091484513,1808.0868946208034,488.6788274055916,1303.8482950750497,963.9808649133096,1977.176739577192,2770.0914126430303,255.75161262011798,309.1900575815597,237.16924633706643,180.62191595059778,3846.4952002054506,1404.692164784402,745.8836579333731 +4,1214.0695102695297,7294.747998600694,2429.6450943731224,1985.7129372964323,519.2717371454779,1215.86862622914,892.4086481348347,1921.9996244089268,2903.9589662519315,193.87470199030548,275.6396749978092,296.5823507686031,199.91266046871,3956.056066142721,1537.3997547090307,916.3648366687227 +5,1180.5644749552614,7507.17961116537,2571.789410268783,2039.0762076777257,421.26877466894024,1163.1339412479751,806.377328566198,1894.3115741565052,2885.4065922620916,226.98960416703216,342.128620031849,401.4217692345546,329.9978151918066,4081.547017630741,1676.116867442315,933.2905882059263 +6,1224.7047603660887,7407.700898093777,2489.5779191585725,1945.7867927267828,375.336084306059,1129.8971395789656,670.2526328328049,1739.6010543520192,2940.017273854194,286.2065004202234,450.2509902070052,468.74605904838813,345.0188771927413,3834.7082724680963,1548.184573213674,697.9427835729385 +7,1160.1293000631065,6933.523094171837,2486.7023496362913,1753.0422678053,275.8619061667761,1014.6519083612786,684.553945548555,1805.3718103406559,2807.3172874916468,326.3779148013999,427.1581286012663,380.78351483101744,230.17090452270594,3523.276251937407,1393.1356869768356,713.6049164973049 +8,1271.6251626633054,7276.634498063058,2559.5380537951855,1976.886139696935,341.8343772971416,1164.6129586398672,911.1470454402347,2040.8701238617393,3007.2142562277872,309.5615380476903,361.0795936404886,297.0729857726715,128.1608755967526,3834.337940125184,1565.6406509581166,859.8231845953732 +9,1166.1360407418688,6887.712190532319,2378.2953573257014,1868.2591761146844,313.6057441977176,1131.2634652789216,944.6704269594989,1988.3052357459355,2798.1500840353056,222.82495672579628,287.23268272648573,337.6392283744881,161.70078429798355,3699.1230785710427,1561.4916546588859,836.0694358115003 +10,1260.7644956315119,7112.972225815915,2426.0390856416366,2047.5655590683461,354.3719979060145,1220.123187905753,993.0509808563199,2021.3410203702056,2812.1959283678743,171.08825279307234,329.4830479034397,462.8362445342101,297.6590462093425,3870.522823922246,1493.5717538934468,725.7440046991989 +11,1324.767868625797,7177.971678096597,2353.7310613007207,2019.3497526310869,434.89903884983556,1263.0838042249616,933.418391102556,1858.1884996531007,2774.258822747422,201.63125433949682,494.1617251445538,553.4786913641619,215.16997737841118,3879.064909660275,1412.9138443241363,832.4722266944501 +12,1274.5480215938476,7178.380225682429,2347.446741323331,2056.408034021699,495.5430436560038,1313.5693244494769,889.3055093104801,1795.6613832085868,2825.26035554813,308.31207960849355,511.33418529910756,421.5588142517996,85.36126940967603,3764.8941232106736,1599.5659955539336,903.8397946788524 +13,1191.6353160157732,7236.74167659451,2455.889670152143,2217.3073813038395,468.4619465811397,1143.8617226789715,819.7670561436108,1876.3695650079355,2907.594154976519,362.29065191004685,441.3578068444726,384.853001833573,93.8945215911836,4210.296238850455,1553.5502346356634,800.4852205108929 +14,1145.6674691441117,6854.246737772542,2347.280405454169,2153.939841020403,466.49885762509797,1127.1464263222308,794.4434457564427,1822.7431453065556,2783.297990809023,296.43401850197,309.639772226726,409.26407495153603,202.83399690261112,3868.1572197227533,1352.3513514717915,646.65741086763 +15,1119.3641302294022,7127.537325748383,2608.6374529387517,2386.486977479464,485.6821701130363,1040.667216053644,786.5154424909291,2027.572200424207,2825.9258432009287,218.45237593530175,348.11178487805716,525.2832982492791,207.5630298133887,3931.52360861375,1328.9852921392676,823.2314662329486 +16,1155.4981903755518,7196.559899132302,2673.4058567609363,2389.0686031177943,367.4689982391534,1052.561471939067,943.1352324205246,1889.5724526085758,2766.7665314643914,173.2551877252159,476.35603640756267,570.5907887701976,91.40172871162537,3980.4326276497923,1534.0233963234643,894.7413123845197 +17,1251.8649698105212,7394.480914284106,2787.3342049069906,2481.427288371271,324.3455623746878,999.5079864338826,977.7119470154184,1939.4936494513122,2845.0769178886235,222.7484622804227,575.8178675101883,488.82955990522026,44.237283513367856,4014.645703243325,1661.3428516169931,752.738647610733 +18,1172.2505109786134,7396.671956262273,2698.4615896952882,2437.236224286564,269.67076289998363,1107.526455183087,972.1730588040948,1872.3903106830073,2801.7777463450925,318.5244659573674,466.77766063176375,404.983617720718,131.39776640428724,4261.126864672212,1463.6728239058586,718.9546802137188 +19,1097.4002415087211,7264.765873259101,2720.0587040307646,2310.955062369673,297.2628881127559,1117.963579484021,935.0586656539452,1784.0389259992926,2848.5239959342366,331.3815724207062,372.11716983219355,450.7210360106143,201.4012676035091,4206.863111803736,1258.2910579759787,849.6119213422157 +20,1210.9892563670944,7761.778900199965,2879.374495007197,2336.513861980383,320.9698036517707,1204.9097916372352,925.4399807214569,1812.1596190148923,3127.4906039863854,277.28941596724496,360.722843206304,593.8949267510523,163.28720011049774,4329.1849174037825,1545.866839244076,889.6757763863114 +21,1204.595403951423,7769.287862569286,2852.0704946709297,2321.392195725619,408.42075494651857,1239.266191985429,834.2851246492112,1900.113529657344,3115.521446712909,179.16453745913148,503.8155206298622,622.3538258477845,-0.7990084773221895,4320.134929372896,1662.9370051285464,723.2892159796717 +22,1183.6762490987244,7704.646631373182,2805.274256668521,2345.8085497211814,467.6971564677179,1119.8457973490933,766.9688938349474,2002.2515518119765,2927.6305558759445,214.86642891651843,590.3372567969327,517.6089936849814,62.94477896310872,4148.478073805102,1464.6782433104709,783.2770984017143 +23,1187.5508314192239,8204.222493961657,2797.2434623066124,2481.708151560875,491.2907267139968,1151.1692727595096,892.2082970389683,2132.8123676960804,3001.5534503331487,275.3487292844441,552.1604948596505,441.7354408921744,171.59306018970088,4375.985039595521,1551.7553454328274,948.7926638195188 +24,1249.3869461938807,8371.631326408744,2899.6937262915326,2690.0418758831406,457.91829219710473,1182.4621639231473,1013.8078677475456,2025.635061595966,2955.2774584956846,386.63796572481215,488.35788518630227,557.1416440141217,190.01257098829717,4151.807810431371,1643.9586786383747,918.7903350641574 +25,1369.1944549060454,8222.950560202296,2815.8962268421365,2632.51277660153,453.8126533304794,1066.4987195378694,1042.582415557717,2003.605478369024,3115.259923632251,366.6421867858665,443.5476868494218,688.9358944834481,30.67750965002341,4081.9552198026895,1546.4165299478184,792.6007344059269 +26,1420.8248920952383,8803.890919563275,3091.687624539751,2963.60746867278,453.9905893173028,1254.4204906474442,1078.9789295120765,2132.444290987613,3263.4107476356694,249.2329147767512,545.7396291994676,655.242920834494,11.423585216247943,4531.864042385398,1511.5003594921118,909.2964867254242 +27,1508.9773842065492,8814.911588611316,3110.5485819589103,2918.415006298744,393.4043802201025,1299.788494472141,1032.0359708892447,2083.9709054764917,3308.4522993738406,208.78692998453175,627.2548657301826,586.8865382911815,92.08714842685964,4403.649025432484,1486.2865672704493,938.2290317900179 diff --git a/python/cuml/test/ts_datasets/endog_guest_nights_by_region_missing_exog.csv b/python/cuml/test/ts_datasets/endog_guest_nights_by_region_missing_exog.csv new file mode 100644 index 0000000000..985e788e97 --- /dev/null +++ b/python/cuml/test/ts_datasets/endog_guest_nights_by_region_missing_exog.csv @@ -0,0 +1,280 @@ +,Northland,Auckland,Waikato,Bay of Plenty,"Hawke's Bay, Gisborne","Taranaki, Manawatu, Wanganui",Wellington,"Nelson, Marlborough, Tasman",Canterbury,West Coast,Otago,Southland +0,64.17299227679365,256.12957833548916,122.90063742259679,153.34482929479051,46.46812303813159,92.4040361649103,110.08256257512701,56.66954966645217,210.85669076131327,27.326230187754817,181.95421948782865,28.68977607448827 +1,60.77705041928116,278.06747055665863,119.22263536632272,144.2406540652008,64.26100497982999,105.20464266565688,,82.85207011198217,226.33968447296206,46.95904341058122,245.33139736074503,53.20952051806148 +2,,286.20150559049614,155.30774275799837,184.4523636262992,97.46749769480458,138.90042530860737,153.20324761889702,122.73479246507993,281.71789223943426,78.30778562298264,272.93376361830417, +3,120.90289466573269,309.86477242189716,164.1905893440761,198.39837519061413,115.25490956553529,145.85539965859192,174.371983200797,145.15648613560853,299.63612447645215,102.8307302306984,270.03897824474285,116.42293046019745 +4,,371.2532515091942,181.25369073799166,235.0857238307511,133.7524546453504,165.7603696753859,208.93896611624365,,,130.0216247572543,300.0633662954886,149.90321222941935 +5,228.23618116065984,385.44661104346545,,311.4773236923793,166.86200467142518,176.774274476457,218.39681709535606,260.508204364607,434.43537375102846,146.64284052269358,367.8470975961556,177.30376402709825 +6,325.21336380889073,441.76290051229415,376.55129321034417,,247.53173885280978,227.3353634610168,255.65673062494136,378.01108559779163,520.4623297563426,,479.56740405219574,192.4104156000344 +7,204.61183771478898,423.87418945004424,245.17689691980723,301.63074835262637,195.48581951083403,206.2278197417678,263.96313312914015,268.46102036047887,449.1025659543392,169.110591265166,403.8067149093254,199.05287386010514 +8,204.30811291922188,,274.8917568575331,317.0811977835216,199.03009334690532,216.639497028831,271.93460598197373,258.87921633215194,433.7038963437551,162.2669520553419,396.5938977400008,180.90230611398374 +9,161.87620905862946,373.59152437080127,201.9007925767581,273.9140875882694,182.82618442819847,192.8776036681976,247.1140389232352,212.37629648619443,368.0056295337654,127.01037125154744,315.3105160221565,143.50893733609004 +10,126.60232102174227,336.21910792442134,173.13405275843178,225.68457106122156,165.09092219924545,179.78572429471768,235.16480582693606,177.2491930292626,282.57705299769896,95.98643834926476,239.69498629776925,103.89914563993854 +11,111.06879773889949,308.21572617159416,167.89809977612998,201.6239570075229,158.04917651712594,,217.30342787559943,157.8674643416101,242.88513101196529,81.2658524947087,,80.04913176988083 +12,,352.80875130485225,210.60663848803125,252.07100351202922,162.05775637808352,195.79981806951434,226.1844818570277,151.83110134708465,,62.81503689498883,243.21795070013076, +13,131.06657856157304,363.9099810242821,191.72030784758442,232.36619806938404,157.78352257318485,200.9458649977767,206.4563316660434,142.28504238229226,250.97360787994955,36.07834296563914,257.88666292010265,26.966873597779824 +14,143.87771315469763,350.8848098087916,,247.87286439150782,156.9976139987921,199.15193184220135,198.61159835765162,129.14177682451577,240.23969846284785,25.421446120095705,,12.523742692497471 +15,180.40585538255988,,255.79859476493854,297.1626091037754,179.8581829103668,204.45167186150525,203.54859712432255,,298.8972828348238,28.408911704079404,194.03251726764415,11.259828139577472 +16,181.3579985668115,419.9449491304847,234.03596348581533,275.42931928955215,,,190.87932475074587,131.97718404821072,290.8815240142501,10.860336039343636,183.23851122378505,3.239379471185906 +17,297.3384495575763,428.2665411205504,331.66440891143736,337.3786006821596,,152.29879715318924,163.7935918519268,198.43252391547782,,0.5844039412331199,195.82891900236592, +18,465.7102413733304,472.1781997374792,457.9824047286671,506.2440883761818,236.1687534259462,191.5738866968739,164.74804471502333,,399.36368160634885,-2.1679120586377536,308.147980587583,8.193393577783723 +19,250.40352240123477,482.52650635174604,300.9388226712018,321.9354203546028,149.36199283047924,140.14301118710026,,,302.5650288743075,-23.341912590934967,193.51227729156452,-5.109028783545412 +20,225.80916149188334,471.1898725051267,268.5413313115603,,132.8450775501953,122.59481550059678,149.5033587335011,123.95983653483941,244.249728340452,,178.37638139598306,-12.852109339523224 +21,225.6124568166897,430.10771447384957,265.81012870178836,270.407689193587,126.83142011547359,103.58210059799168,126.85194635310965,89.40201739805157,210.30058061273255,-53.95532619358271,173.0994080907351,-9.355869325287884 +22,162.80261799104255,389.47500473238216,202.43455558967486,,86.6006391184157,61.34834112876805,97.17986374245346,33.17384847213427,90.45021298406557,-81.66265007966079,62.44601757166957, +23,139.4667759854316,358.04714760162653,180.95405616836277,140.84343766479742,,35.92270010806437,65.16580308472683,-0.08583659490479789,31.25439851386527,-96.79193901485021,,-14.056016201175382 +24,,393.4851814366358,198.84964530287294,190.12961018927265,67.03536052454243,,94.86265447563581,-0.9630516003743423,90.04771739076827,-77.39470419912175,145.1038190063706,7.099633305599935 +25,144.63613875307706,370.68886656493646,171.05245871580846,148.21861970203636,51.57627963666792,21.44254561022032,69.68619806023389,-15.122879587808328,87.1500962880283,-66.64145006471051,196.9710961578623,39.980577235020796 +26,146.9746187046207,373.0090949628321,192.14279977074358,154.5645506630272,42.82657124947198,25.715268946743393,65.46600007539678,-4.583333683781973,104.68939934178191,,206.73826943981575,79.37141111686914 +27,177.6427030076839,417.8901057744501,213.03346560633312,197.08791053898153,54.93534159061152,38.90548627400287,87.72212970730871,14.785417938360865,203.99983553965808,3.7713218212170148,,117.81856779552692 +28,184.04599761782114,436.77628949499683,177.05340008297966,155.20086427497034,25.47367600198011,,,31.735622126847062,229.75971837874548,28.14753166727553,265.62610568888584,148.98675367590488 +29,253.52553800779958,407.6244380860303,274.2087691466604,217.4710595827851,38.63634256995729,2.410631779727524,78.45674183491622,120.24512974632302,281.4525579868648,65.039922090278,334.75122862248406, +30,338.5517949131872,457.53409421110484,392.5609497850607,388.7919358654618,94.44522098570393,31.216667881782314,102.1254188010646,274.3656915848173,412.48032321197707,115.73440957348835,512.384294686282,204.57881443826506 +31,198.90228373463884,453.6532979817648,202.96933521887738,148.10767647105018,27.697703261821893,3.7202906388561843,93.06936404299063,149.1633026029337,334.5922031161586,122.71163571275935,,206.23130715493096 +32,180.06980231728102,465.474800429703,196.4274047978714,108.83206366448316,17.484008787162495,-5.331468053860135,112.98730583358682,142.47411473295034,323.8079293251147,121.93993934894277,379.7632690535629,194.90025526395942 +33,161.8055178199014,389.45431380733885,166.665616941418,145.78322500974141,17.223826307162632,-3.7297270335227495,116.11248552018034,139.53723660212967,322.59599735575637,127.34764541984913,354.8790702616999,171.8841852179261 +34,99.04106313823357,331.0905422210189,82.13429248039212,,-22.952909476563704,-39.8372008589518,89.8787132517622,100.71318017900143,219.48547507762603,113.85250236446008,260.2813871891552,130.57731001000872 +35,94.59120886882394,300.9231147745463,95.39691540939094,32.133461042671996,-14.546892343941508,-32.89152468285715,104.73427285052331,122.27716632763568,226.8481626392898,111.55127832491553,,106.26740078818844 +36,89.46904996036085,340.58369300231544,88.57443858835406,,-15.08039491651411,-17.09967839366601,132.48439649205466,132.59666340470758,266.9921522622426,116.4546028534119,345.54116570122505,76.49843559223382 +37,78.7761874004413,305.6489706843157,57.74386076779937,2.37880589133556,-23.924914783836357,,121.94738685694885,143.41018749327014,,,358.9002421045252,44.29209051187635 +38,87.7982752569694,347.2320250567426,86.14167928134773,52.96354412914644,,48.17775070751357,168.21185609027472,188.42803621340437,348.21272741486223,128.96548331498562,352.2251275667646,28.925069187085917 +39,,358.27961605455937,101.20873923350788,,-2.2026134676020206,26.14050147725409,172.92642864925762,202.28149772042076,379.40636600939825,123.70296011548663,297.4406185453771,16.02492797351267 +40,102.25240184051947,412.59385615640014,85.6210783160023,54.72906960373308,-0.6576275091539117,35.17318329329908,188.86845114653366,252.60993032658718,409.96438840170606,127.86437920537657,308.5014558877492,14.110870312943561 +41,,,156.07483945495125,126.44854253324439,33.01366924528242,55.07318987863034,208.2614790589131,327.25797900640066,413.5303060549219,119.01425719152637,313.0385241145576,0.44387221343299643 +42,262.1510168954733,441.97605359597435,298.0524197758136,309.58888797406325,,126.5422880205808,247.57234111533467,468.17381461806923,492.8620749714096,132.4970529223614,409.43199855011545,-1.6273728912755985 +43,122.39255586970621,411.42796955447216,138.682285042299,90.95946322127693,,94.32047533888087,251.51278452735085,334.9346200792453,,102.22517309370332,321.5298195507169,0.4112710612371586 +44,103.98607847715762,,118.56089236625172,78.81410564347104,52.89899600386705,107.81505136021094,279.35519854796553,311.5240593781127,399.0738070965356,,258.2547688386521,-9.604071735813761 +45,87.88420083231307,311.9827551049291,120.31149648347217,103.51048325211252,69.21869209153824,112.06290381951389,271.0526252149733,288.78453846365375,366.8124744760913,43.1504369326666,260.06047619150803,-12.481381543059399 +46,,242.28334463868515,,17.081967143523883,40.99961210668816,96.29237747050709,225.54287694346525,,210.4537077053119,-38.09329020678881,119.6677135815307,-35.94100199887923 +47,-3.920275964665123,224.82552282348942,20.920563782153522,28.738041103151588,58.08726813819797,112.3542736410262,222.04095898337664,193.79633268352313,164.0701172051015,-63.242003607510824,85.76113536567962,-28.8388782286788 +48,3.2262218419940254,251.47893988054017,43.5245781530427,82.46583063655368,82.27605739836471,152.3532365890202,245.01553535339644,182.4316336381263,196.24537619542448,-67.85547092184234,206.18603821858463,-3.493710509846693 +49,2.124373804649437,221.06395889196614,,68.52426428263004,,156.89665974224914,241.13397594671483,169.39634810641724,152.12681808757554,-73.9022374238744,253.07373203938386,21.61470641647526 +50,21.150219915249366,223.85606776942586,58.83056506574587,99.70629752159178,103.85142521794893,171.9604493374434,239.91742066005645,157.42014252752364,160.95227493796276,-72.41920819058053,249.54946479002967,52.051622620459014 +51,40.77944762343958,,86.7682180368102,,129.91252219316252,171.63088990924427,234.41224564065453,162.03865001512992,204.1492208816939,-64.29233498745128,251.5964090017958,79.41573965169167 +52,48.22966436572857,327.8127138786595,81.56164504281172,145.24408896134523,135.15126523798426,173.29315887061156,,165.0719729080512,,,321.6645771857913,123.03078173350481 +53,120.61690689892592,316.40086160399284,201.45930740784368,251.50838450790243,,187.9703498658829,229.52967133198757,228.35744591580436,277.66756811938205,-21.144367882725575,398.5758699247986,143.6148915081157 +54,208.6038434374749,397.1726054278074,318.2844046117546,468.3004480017964,238.3224949342239,216.53840059845356,,340.6926173884726,376.4089139560837,35.73653247717317,560.3627436885515,190.75814326261448 +55,56.10852328545971,319.388731305277,,232.7785906068631,182.86969283087126,179.63099055103515,212.36379541602582,196.59164472754242,283.94543867959044,29.655042014651002,461.4234455952509,201.33630668670585 +56,39.17941596818724,341.7560761337597,170.092497525716,250.66853478872656,,174.05528918150287,210.16777427220038,159.15751346050024,286.602685462135,50.995230051506184,453.72703428427917,205.59902559470834 +57,12.282230085696611,264.2427924305057,137.98653323105276,252.6611661046444,184.9843412672884,168.96832728141095,163.789740654862,114.70388591830701,245.86419654152985,45.45707662526305,430.5679815039892,171.6368753903642 +58,-51.65103378004602,230.4241397309065,,169.13425616449598,144.07257509612322,126.60231761869018,111.06247363362122,46.631419555408186,121.72271377668756,38.86852482710118,342.0654518990314,155.08088085623658 +59,,209.30488882104117,91.36764999993558,181.38682281294587,144.99926658571047,,98.94615938652106,37.0942972726742,110.07774385532502,,,124.40597977703126 +60,,261.25551263107224,120.14419216502733,225.70522836159455,150.74283330139426,134.67997612677306,108.35528895996381,53.38677590524128,179.6045871451727,85.25573970276476,458.32061348470376,115.03135976571356 +61,-83.39126700276157,270.9285566449587,109.23344842333364,196.11501939956105,144.98467533455698,131.8370568437007,89.5464971638898,45.51502192765558,204.53308712836065,97.4837107129506,493.9077986234868,90.48806193190055 +62,-71.55543633339929,272.5846349825465,,232.49278076127797,,128.93905766242975,89.7806396089191,61.98029139968253,,125.22965524503343,458.3755298303417,73.5162445328661 +63,-45.26341677965712,327.69662816912614,171.52013127228355,,151.43354392864615,103.23511675853766,83.87066381043131,76.5149047676897,299.03825612646375,157.18714211146926,431.61599139186296,60.467701659489 +64,-30.24793284860661,360.53899268340615,179.52226387757722,270.47713803737224,134.48574982341407,86.61790270214428,99.87967845022585,116.1804611216806,350.2877514932055,178.5381613085191,429.2956278324089,54.49557570572105 +65,60.798881390774426,,271.48007590588054,376.26558604342824,166.13976467410794,78.54816088891153,96.12659289815589,208.68262867653092,430.47865174226945,204.49755251620644,506.66425079380133,30.672858434408653 +66,161.28255499146846,489.31848490419185,397.26039621146117,557.9128761265697,226.25822715668812,109.63153877736298,112.84348205185651,343.42989086396193,534.2472493473085,239.8898356261254,627.1515572945323,27.446554102585637 +67,17.413203400153606,431.22517719455675,261.38720274965294,343.2064375595364,136.63642055943123,57.65859658096001,112.17581755897557,257.56707932914605,482.9102925043161,215.79835849197167,477.4015674644734,9.216715152184577 +68,31.20841653455159,,291.21613031821806,378.6036626837638,137.71950053683344,,132.52159020126146,,563.336925874782,213.3459481038922,469.77195560554054, +69,-22.068169584909924,397.17201788578467,240.06550854215874,304.8637707975173,81.27392668279106,-0.6679762781022589,91.07527261697777,208.17538447051663,406.66778789417526,173.03405062903983,363.9570788085363, +70,-72.53505416175429,370.87678573709894,188.18095296702302,248.5015479826154,,-20.44089554850177,65.70810997131532,163.47573332151947,304.22869669604984,123.16129824507749,226.82318024525665,-51.93223987709534 +71,,353.3883648152892,193.44413057420883,244.76999630963616,23.005310521724546,-25.20320752936415,73.31204713551242,169.2955571972228,283.77735731410615,96.61055667502713,192.44878410338316,-61.59461685873174 +72,-92.35689638405748,394.2777610953915,225.1619483038391,273.97203970858135,22.985090596362497,-3.6961080945401363,,186.01880184840218,359.2606877550411,89.23583467911112,303.64136437711363, +73,-89.22032735164044,407.36791197237454,204.03360776948176,240.5015214419725,-1.4377790821071201,-10.512626812274846,104.32735287757889,202.3002314197903,328.7365435106419,55.28682342887023,307.7362895495219, +74,-80.30824148787326,,224.66449732245363,250.04886302449057,3.864705925743621,-1.1789304630115396,129.66395960324002,,356.24952523240944,44.87400457754824,249.23592549184139,-6.566481003467317 +75,-54.90233259843677,498.5192570664144,,271.40967356114015,3.62538195563441,-2.360258287755414,161.15656025131034,259.47301920404044,399.4479136246562,52.91891397171732,237.28913196143964,34.446371480383775 +76,-37.75745045615432,581.2979234672455,278.0268618742776,270.1637668324553,-12.659251486643399,-7.778658922129225,180.2365870908391,292.7172444216353,431.0725830665922,48.556483893941724,303.13189890423166,78.728589576275 +77,54.30723841989186,556.6211711972882,356.11972845705213,340.45258912472485,17.815001609139458,11.183028594304702,186.06401030265957,379.78370405992837,449.95503720471055,55.350028639155155,384.57398159700483,109.4022242110454 +78,173.5632861529246,630.1842987526522,459.68461038295857,479.30892401669337,69.73807347173253,65.57016132567679,,488.35792659793555,,88.20927527003671,522.0697819180757,146.72277863281224 +79,27.468637047592466,603.5899568127384,348.0316353328024,287.37045759259814,0.28611710877876817,48.926520051312735,241.88992611255293,365.1302230850633,,68.92039473585601,446.0182137030126,160.56535443294007 +80,10.69251790189594,,340.25364517128435,252.80576008637152,-7.148672253332236,69.76450104356879,258.3834885369376,,389.73439125619586,74.0762772018507,, +81,12.468423434946743,537.8815105399826,323.97810430104107,264.9371112698018,-6.556414572569523,,249.42583753511752,282.8862050648454,280.0633342941316,,432.1569948258721,160.47386483989945 +82,,493.0053490193818,218.10903282116294,155.06531313778055,-57.857980257276466,57.73197833726465,239.13775429832316,191.32596404651704,140.84762775784537,7.813584088317519,,127.62877319264045 +83,-56.45408672790876,448.95726570159366,202.92870118442423,135.46436878019696,-66.0359688061256,55.26680371887384,226.91406132322493,,94.3931767801587,7.56308464881036,321.5433813565723,109.55778795502643 +84,,511.6561102451385,233.03726897670165,181.4696995283003,-44.83416750338,126.1984813304295,250.24826251567958,142.5470467079012,160.6815478446955,33.70520769647068,500.85326855433067,102.56541864271635 +85,-47.15857078857546,502.56432387802454,205.2699348195347,135.26836782353175,-48.54744703914662,129.16967311969574,245.9575117542364,127.68795644228732,142.69163883579364,43.275002338496435,504.6965957585388,78.28448998803161 +86,-17.006046716538776,514.603311678906,,163.90176448342643,-28.06086095184473,157.08766413660587,276.6285858072224,121.61664086797548,176.46365548392697,80.15736626622531,,72.59043161297313 +87,11.02281670485516,558.2670744438167,225.43283167392303,177.8323068361647,-11.021543585456115,164.27597671334487,276.4353508509626,,238.10710314268383,118.69104574301736,477.71255563967566,66.17695181526872 +88,33.268593290266494,630.4525126970733,,173.1538987969542,-6.212160964121296,148.41377320134953,289.7998528248751,123.65759279910272,273.08614394478866,165.65118029402612,525.7342383117553,70.403892867653 +89,130.88020894105412,654.710120653681,,241.02052588837458,40.473724289383426,176.52322092512065,271.59666873712314,199.4069950530113,348.07149994676973,205.96632406410694,614.1908913376715,62.495211990603295 +90,250.42485437744426,715.5145982437406,421.4813291659026,423.19243962494534,128.9386131897532,237.81614851088057,289.3335918751541,282.8817734741126,459.43842003631283,259.91851276192347,719.9470576630292,45.22627204510957 +91,112.94266215533284,649.5387560341304,279.6328266573239,196.22544248278695,57.737581519416636,203.48942326917975,283.5092810562507,157.31113901083964,403.1523696569012,263.26132239168317,594.1492496106712,26.604443386828436 +92,92.18408660538739,675.3620826883458,,177.2333972411624,68.83125769969068,,,124.4274721679651,403.57208393497837,265.0195052484836,539.383563174275,9.51390951529288 +93,,607.2563687125105,241.70120673395454,187.92450089624728,73.23959701136667,,218.11591119028702,82.77575775095235,,249.11438499405145,490.74544019874554,-14.115618428654244 +94,35.09043881201346,535.5648159710345,132.3720205676046,96.97917325627738,41.63768816169751,167.54711358348945,173.21862823439992,14.326472309606686,221.08820616734496,,327.6479415832477,-63.25401137592188 +95,40.120069586868624,,138.91718989652992,,43.7414959748841,175.9414338821632,143.90503210842868,7.787426159777112,202.9129352528335,,310.4960737016622,-74.31181722921008 +96,54.65039421758523,,148.83391199042117,,66.09948501799383,,159.66724243262325,,282.97325190260057,,407.92690861079865,-67.37292838391392 +97,48.90485330286329,530.1693326095208,112.23976998603989,,72.63681724054659,,115.89177411547672,20.777696583033986,269.09107644086777,173.8923315882547,394.0460823324714,-47.40816981929186 +98,88.01876231592871,535.8216726200199,131.1228906993824,146.12289111665675,93.20332531508896,214.29076341932407,124.55433417240903,45.09754966815989,347.5488558184219,185.5845825516624,377.2024700339626,-12.10549791245873 +99,110.76461201572462,578.4677744112118,132.7368227330573,152.53943315592937,110.72982650651966,197.05987529973314,110.50708354055823,68.46114015468824,380.8742438051952,180.44195586456362,328.27278834627907,24.62078936045507 +100,130.76820413433867,604.3850115753203,,171.3369391125921,131.102199011868,177.15375647344473,123.36374315295578,126.08235207589486,457.6327980324498,191.3740788249341,364.89725890824076,79.63147589968389 +101,224.84718084994287,617.4574438267839,201.34706063768724,,187.80156756946684,180.95851133380387,108.68546477183409,212.50713373242837,,173.73382512925474,424.0824276113444,113.88232605430957 +102,339.0497811619692,703.8584961340576,318.3441386103,490.1747496478347,282.787749872064,213.19086832657408,130.60529603676594,358.38567426036667,647.8609038467291,188.2006317006596,572.7274724125084,161.1150096580065 +103,189.25452313212475,620.0001839814402,163.48590714196388,254.67566507777292,196.9577531944496,167.3889652706843,127.17917039573273,236.2706870872554,520.4479755207942,149.81916312142437,479.9199746684373,171.54726360489258 +104,216.36352845689697,644.6931261769353,,283.3989830144124,210.1516723734479,156.20151633033598,143.14342821663618,249.02699287954331,551.5508587796952,,491.495844476385,187.93056099470607 +105,172.91814290810106,555.33420869933,127.87050997277873,245.528145599672,180.84725166564237,111.33621761051612,109.25013863498724,186.713764502497,402.9163043504809,60.984724179550916,416.18256795231946,157.4899968405436 +106,118.744328014503,464.37744634791994,21.045430414354882,186.75842307299584,144.83474706954223,59.23873101120479,76.14445206369737,146.19705954205423,235.13987657201977,4.721116502519614,335.4839609546828,130.4676369112633 +107,117.34434638582786,441.11289960203953,32.717859498596255,205.1380382307245,,59.472689736603925,97.9266450044191,147.21355487177388,245.371483515816,-13.81818502657957,358.8136870850168,121.29670867455525 +108,140.10095753916517,496.83648297815637,56.714491625285916,254.25647559973095,,83.27346224631941,130.3300805205995,166.8217427664206,244.01238211384796,-13.05236606413041,504.5113970361652,118.96502645271107 +109,128.45621549610064,441.45483911096875,25.062550789922682,,145.7520367802772,75.4382301672851,114.25795375035051,174.3600929845716,196.80012955780802,-14.278058914264193,527.7832821815738, +110,152.05116767408202,449.78652908906,47.090835263287744,,139.0659852192204,65.37851785463329,150.06203807367302,194.60423507586327,218.31668848076052,6.049329523328346,543.7523030435369,106.69533526337835 +111,179.25099514896436,465.1343421645922,67.93345640381594,298.1074373045494,154.8239288251321,54.02352454400358,178.56200869734064,227.58782208566004,262.95536143659115,42.95472675492295,520.6141311913881,113.19015833329746 +112,206.3849473580332,513.9964558128,,315.91781387634734,136.02254149448171,25.755261856217373,211.7642402277362,,282.42921560387106,75.47564117820758,,112.42127755045867 +113,274.9951546985061,,120.59067831320783,,172.64344936541477,37.30450781942643,212.92923896347318,316.7686360994534,304.68088450338394,96.13755286425447,666.1626069150196,98.19047138114571 +114,384.9151070314955,572.3414080256281,,580.047350248934,250.6571611941414,86.60871567208663,264.96976280031595,,427.833643602181,160.19386187163553,823.139646496576,96.79595735690776 +115,259.45204429656576,513.4161375695749,,,147.8603017001807,42.39281301015278,274.34485447422713,328.46744037122363,318.4411259206576,167.16948267757422,709.9920156850153,66.95597140610926 +116,239.90697430578362,517.2804139852909,132.1898718273705,379.89097727989224,123.20506164271144,39.34365606751875,293.05764185163076,280.3305755618509,288.1416838334973,178.46843092375548,674.6316268526189,26.88563580508054 +117,214.25010416466574,440.80175595404444,100.70107913907839,,91.8761338462129,16.13831305698716,271.411934556475,,222.5953309212182,178.73908795326204,634.2287868681672,-9.843437131307951 +118,147.85334659301634,365.86419604734493,8.894004657595758,283.31513677514704,29.65404422145903,6.971685987286179,243.68106895491172,127.91596338985974,71.87925254517754,137.05604900831867,475.6706019778999,-55.86764479303626 +119,,337.87154710546105,22.543859924884515,286.1968477048378,10.14007676811822,15.86711889711583,240.59172465804173,101.05514665841272,24.47355716228992,145.75207670466895,424.71899681884,-76.58043669723227 +120,135.38068816054638,345.2440784455153,60.88746541298258,311.1997614200036,0.508008909339452,71.24185075433792,282.13262388539306,90.22846364080661,81.7028946215386,155.61311981334057,529.6477657661802,-77.58247912137212 +121,124.31095334749172,349.68960257227275,40.05937566828962,297.3644231067133,-11.792499068001803,81.23594969074725,272.31229282115555,72.29696017506063,70.44634741645837,168.38519239613035,520.1110747998816, +122,143.79844611737553,364.9787566070719,77.76118046592137,305.6260691040102,-22.546378292535934,108.28647394476998,,,139.94827207774614,187.75547056796023,456.4458852629072,-41.04245159643297 +123,176.07073215057915,,104.16273528612747,333.7183828660599,-7.360920256103668,92.04881907111783,,92.57111670449201,208.4422152516856,224.42890124382978,418.90610360720643,-3.9353421176689523 +124,189.22257468757988,,149.34830351224835,338.85807559433124,-23.18159514334498,95.61555248331926,332.7660530004864,111.30809086210735,281.73790044394445,234.31276887278096,451.2888011407616,45.37014318135237 +125,256.4227686246002,416.96023681273704,208.65391264488528,430.364436742182,19.12027609823781,118.33558122298503,310.4807632278242,172.7474843547144,358.8182722613646,239.50146874680112,497.6328270151405, +126,402.9182561859288,493.41592452267287,344.0792194173592,,92.83057926228724,180.31296349544252,337.4010370281948,293.967651430065,491.2593209891394,265.9034170748911,631.6372225765886,129.5049957028051 +127,238.71853794205953,452.20305502361487,251.22640691197887,,-2.1903647850540153,189.56685307285127,336.6844093831895,168.54190423998506,450.4989801111377,243.84083344593554,,166.9299080905383 +128,,500.0343737635326,263.2842393108515,367.75122748129115,-2.2608939159304953,195.33898901371026,,131.8974799009186,448.3786806514318,203.93750575202688,,192.35263098985834 +129,173.6026759149479,388.87824709046345,224.30083730074742,,,192.46814968953345,263.9109978055944,79.00422671150758,401.1797925796765,156.2763618822195,456.828387685105,181.96498172069124 +130,111.9321049442778,315.715483262377,147.3013427956925,247.77297669990014,-59.265297482291004,170.47602163590656,212.82764012417232,35.89190855995727,257.2571664378398,,327.78474866562374,154.78479509785942 +131,91.10486195674906,306.84408831852835,156.11847629425603,217.9881570711172,-51.75403107476842,,183.81354589019344,32.46829127370716,250.38680915392843,52.743962773684274,319.4116233351782,154.85609967485917 +132,96.84751708257801,344.97862158416615,191.86380956384565,239.86322170333924,-54.18294347864685,247.49149999126428,181.36358355155596,41.7808220626176,,40.30498731891328,494.4865731013461,157.90135467402493 +133,,356.0902564116026,184.347977571061,193.68725211872015,-50.43480553488686,259.8208297800775,153.2356516593724,52.51558260243048,311.98427074858,,511.86273787855964,155.39522385594685 +134,101.671116701806,365.91896140163107,217.92149308165975,236.48303958396235,-30.955156736657344,269.42462609418965,160.61765771604104,77.14787167130734,342.43948406358317,8.515327040108872,484.58046718148944,156.11076385383285 +135,121.97249425472617,395.56212656211625,228.21316885838388,236.67095868138605,-11.458254083532083,245.04871302256228,,118.7847807143352,,,489.4647075585867,147.88122869900104 +136,121.83137225090792,478.9986568416986,246.24217666918543,235.13353853017955,-12.709383348606835,,,,449.54602374672976,29.073719490691445,562.501284048653,159.47268962134808 +137,,428.3025922548959,307.76453726847024,316.05424197462423,52.175123751934336,249.8619489842725,139.5998847063649,259.14172440463454,468.96930590091455,37.117304761508294,,144.5081945766862 +138,327.73954606241756,536.5295562233187,453.9084969649398,482.8382434091801,,301.4624498650025,,422.1908771805479,591.4866947375813,79.86260818083575,825.6929238451794,135.76408635345499 +139,159.35319205932,506.6847216461493,337.3067851541464,278.46198525740994,81.15954120477889,,,331.3779492333391,486.2388112359546,86.1059175091339,750.4063932076712,120.51149420064465 +140,145.25550286661613,516.4473336762787,336.2741529596762,252.15199962456094,,265.5866336703525,187.02425239551857,327.41437071432716,498.2362004322176,88.6085919098655,760.3577546060992,88.89709103946942 +141,82.71300684182648,429.45516968742584,273.4867645737766,178.4812153998133,70.09361249421207,207.42251357357196,,257.2709430741493,315.2948445260673,47.33643696071505,,34.58665278911217 +142,38.19719722956492,376.27785551439956,,108.19994703894459,,169.77162985968104,92.88527170712764,210.56653446033226,180.6914886440016,14.832041822352295,527.0018037528178,-11.083742095358161 +143,10.350422471564144,341.78425867902774,194.35842363654808,69.44959965925673,39.390094074735075,150.411274292431,65.62530431608965,202.78942469108742,,15.974243216579445,478.9468948220405,-51.377844161463386 +144,0.48779770456539495,370.03770597590676,211.2897290965965,,60.849187016724386,183.2057926653991,104.4395195366694,211.65184494425523,173.5552199313575,50.05134375846466,638.3795080629525,-54.71396076458686 +145,-16.785354702958628,,203.182348028734,63.41644415482277,61.08422569672095,165.87380732093243,84.08285781496348,211.0693216258378,127.29670948314421,75.87707309717521,614.4902796210419,-57.746494514620835 +146,8.461871028353869,418.0895068983458,,71.2817714415103,96.35294366839454,149.15496343444204,126.82026122049965,,152.90636011458872,107.26861084367673,536.0522127303456,-39.436825404853195 +147,29.80023344853447,490.9293636884285,252.30166661979675,119.38073341869762,135.04375427264125,140.9580457599755,168.99293383991295,250.82577824347345,,157.7966211535268,,-17.094755456553983 +148,31.822596493980143,534.086660539441,248.17192430891114,101.17109137298247,142.27903337571252,107.72100575086607,178.05887505687224,272.2033892017145,,189.570419206716,511.80379273544713,35.85173818363917 +149,95.24217057252011,525.0110047947285,308.8763034288177,207.80510401792705,200.6383027650494,93.45910147008252,167.8391868229007,349.5863756440716,279.266063744438,228.84494166304097,561.8960714912993,60.88567142745102 +150,201.78584261343036,617.0385808102199,420.535568521029,399.2561160229123,,148.4466114238359,210.38384071754615,464.73580242470155,402.6138478734034,286.4292330936836,707.5530487990254,114.89143901745999 +151,67.5151463933762,588.4442775069433,287.3993788508073,192.32393250120953,213.24823854568243,91.47973758207857,230.46265860501921,332.6745092117807,,280.4791289437456,549.3900827961357,136.02860734055744 +152,43.225881190221756,615.2505960035492,,166.68277619835192,214.76292727285403,84.46037326781128,260.1615407826965,265.87718691509684,338.864940265943,277.0647728294761,525.2998503201261,157.2880341736188 +153,20.478918741475667,533.4266887723467,268.0572475120205,182.125688871964,207.18366916503956,68.37413590494549,,214.32823725501314,279.96817238795194,250.196600386216,468.73258039314794,159.59571753370523 +154,-39.07367782142816,483.2148622375343,161.53468284229803,94.45882319016246,162.50458153016854,33.64246555183111,224.78921531456314,137.39001416571574,154.90682795853184,,339.90114429376496,143.72879056301713 +155,-70.18168625869751,441.3667456604586,133.16420383897858,55.432961853291914,,17.058066180834487,199.3342795448896,,,160.8448126593467,307.4189559355427,143.97701515506594 +156,-58.05224633178068,548.5149051341668,154.1685768438051,112.77851080834014,151.5339270510824,68.7234348923962,229.972680738902,,218.56478304932293,166.66673915546093,496.514037650261,150.40798923553433 +157,-69.21349021774239,513.1952570834022,132.26164798283764,,133.4154963269262,55.287414848602054,218.53389240295533,65.24074337481315,215.25316889089322,144.1256895240143,, +158,-51.333511372337114,557.5453801044067,148.45268279844967,131.78496963889572,133.36879263690878,69.95512426966978,263.4393286059226,60.5737370730863,,141.9546072750852,428.4969661638907,159.19639800364223 +159,-19.184340027559983,630.8280522404352,166.93203275709706,188.91998025731505,,50.08952591373985,299.0061921845742,63.65726941614419,360.2584549652436,,423.31625013539394,170.97574745437177 +160,-16.034666788981383,646.915047195977,150.5520329097861,189.0004606454392,129.83809067326092,27.02309254222901,301.31605921027204,80.98680720355483,416.3489113168532,141.87694220955777,484.22604740359617,173.8140959328774 +161,,656.0474003979617,213.99555187878917,324.314127603453,,65.41891973982291,287.0691889168907,157.45234654929305,486.7171679168127,139.0232711816065,,158.1746736861506 +162,,730.4772714344095,358.0186340734516,,,122.6881738419509,323.66221076641096,303.1687889364138,621.8904653320246,155.29487666648708,763.0924421815209, +163,38.075317270479076,682.7949977152563,192.66501665477887,305.27120958488496,127.23288032362676,89.05893585411367,,158.85640534269209,521.5983112356321,119.21896543784254,647.6094131263258,129.7750529372726 +164,24.42430569719187,697.0145656353158,186.38535711270436,306.22865692194017,114.09536045240097,,314.06422584704853,131.1393398923121,535.3874231054618,,668.9226512777157,104.62139313175203 +165,,,141.44871194265085,299.20293698749737,94.73418563664572,92.85301431400617,257.9566568767393,95.78369666217654,465.3390660204748,58.711228854741165,643.7579667466852,59.844719431590775 +166,-64.45900580276927,569.7280162020363,48.467815976444726,212.9905486177923,35.3062911079832,73.34674606714036,206.90331604548547,,285.2724323019034,6.042798083035969,502.4061353531746,-4.022042261088515 +167,-88.16667634952927,523.6327454030863,39.39826625796903,223.32819111596464,15.12480902178676,94.57606531327505,179.02607962479905,,258.49392880246444,-6.5381489074854855,524.2969425249819,-39.54354021122282 +168,-77.5091645642384,580.4366544232914,61.74721456062238,,5.312905073125023,154.13739866445147,189.4536013962785,37.763710008055405,,,698.0294718712228,-40.6032110277622 +169,-81.08219314191354,589.2325797918318,,237.82486837769684,,151.45908536690195,133.38963488209006,49.65067039901335,,1.7454271239287777,668.2492671449623,-53.24952250591657 +170,-66.28037712981197,604.1663600253494,39.09335014745774,288.99871428019424,,178.60813281399726,161.35179447619333,77.39051304932865,308.97366901684677,23.308286938335655,611.633170416235, +171,-20.02829381836341,655.6557917771652,78.49820750978611,326.69896557703066,2.5566623745202577,195.20751909773375,164.76116259513856,123.87891664864043,,62.76230960599985,,-33.00749482763918 +172,-20.54080365433765,708.5944846012464,75.38908397529889,333.5684419890168,-7.522237733336667,,152.7182964353252,178.9810467789993,388.9972840579092,,591.2062244483504,3.9663729454665173 +173,70.50328290528341,670.4130970208096,127.6769202601019,412.85623452844214,21.4990148460562,220.27866926689353,140.6853391163141,289.20750625784024,405.02689808645823,116.39364151492552,663.7761192317403,24.61561440900698 +174,193.96215115164415,767.8040394744382,270.23160672841874,584.6744318829882,95.84252253174753,284.1013514880319,150.3916074573105,456.5557047861678,533.2269772229808,171.70392350488322,779.7635846284219,60.05041260234385 +175,45.839656652705344,710.9168169351458,121.21995033139672,396.0106756954011,19.84251627207408,257.92410886985897,155.90106147947907,331.8053753028686,323.1881900636123,188.5255301209413,659.6898245145499,79.45617628795819 +176,,728.9385865719739,123.405793665249,377.3949633104489,3.1358637761883017,276.95865419321206,153.50408837516233,306.4147242330662,220.79177388573885,184.84592122808112,598.9957470235815, +177,42.563684859529886,650.0782576413668,108.81069923199749,375.436920951974,-10.038405325704161,262.32792627400744,101.90188854733725,269.7816125444902,165.29832871220464,170.64707215131781,522.7534853890984,103.23909030582588 +178,-20.436906084493444,599.7164877152597,6.84413730502331,,-45.38983350161223,217.88306123223293,56.09321418632962,197.74212397089667,,139.48595118374837,346.9992381696571,104.82270401077517 +179,-27.905698975802522,539.8017916607814,15.007922423955392,278.5829785560912,-46.87669539696995,212.2415945436051,42.112476491359274,191.78790867634342,36.784813255945835,152.6537706385697,322.4096415801582,109.95204992760797 +180,-8.324960632634202,617.4115538331889,46.74245435766642,322.6189846631929,-38.04829572509517,255.44560472458733,76.01129477739718,214.08885211109367,93.46363403122228,167.42203167498695,495.4009691164164,136.33180314918565 +181,-11.46523846718489,,30.931833236337326,297.30057253344853,,247.79630769626644,,211.55268398970793,83.458244891829,163.49032300937017,479.5382197933595,140.99109183933928 +182,,616.3938361252594,,282.88549628752526,-8.634230060112657,239.77440438873816,76.83531225163745,205.2274999354465,78.3080265811584,170.32306750386607,409.12250646517424,156.4341797685134 +183,43.20722009344465,683.5399536203116,,294.60680101380376,12.069610158850793,242.76780912035608,,240.38888530117052,120.63624923295475,187.70247401763305,376.3034290036447, +184,69.21056771765166,658.0001502147215,115.52666370545927,312.2552904679648,21.27588466969287,229.6631444906571,129.5322928661083,246.40270886314033,191.63050401615044,206.26900714205664,456.52324245761884,167.91871081688828 +185,,700.5302114927148,182.3850410394743,368.8435172156584,65.59154425903046,226.54662626186322,117.51502943807392,303.09728001294354,261.13670396260886,208.7244115276318,593.302453091257,174.01468057898504 +186,281.4868297080209,727.2831073677885,306.2779700591272,480.0887504662886,139.14716689082928,271.8451349465785,162.6555480924237,420.3339838329082,374.7471965232653,234.1387471039363,746.8991586923562,176.83371377115589 +187,147.43245106158437,686.7624837584486,218.77806376791537,313.7147824238973,99.9295140023306,232.384726399945,186.45296971832448,281.45631499710146,274.38068737819725,196.17465827221957,603.9226474681072,145.34814662709948 +188,115.13214271734128,704.712119510795,222.1307360542234,276.7976734283241,102.14580260114623,216.48808616762472,,229.86576432829582,,158.31304305809493,601.9855505282937,111.28220088103038 +189,113.48380914734295,618.0357060953826,194.7298581423762,263.25164406847017,86.06056883729336,,155.75322741239586,,263.4685776373349,116.26486836623317,,71.34662521046965 +190,72.30719211698835,531.5080641574549,117.28323506030887,161.57384561765366,73.42528528855775,,145.39155056135232,98.30290161358428,181.18713102538402,43.685883350159706,477.4567253074507,15.428096160491634 +191,71.43073059706629,,129.8488895696933,161.09180814160845,84.67280808845535,102.7704518267892,149.750462254498,69.22446505242189,176.98524987241345,15.75877144525537,496.45247343338144,-19.332599272393978 +192,75.73553621551063,,143.51211671407634,171.93669324474976,93.92103622252188,132.85792653162363,184.08779514240865,50.42717297547743,234.19582943542943,8.211794291543157,678.1223321023622,-33.481992184455464 +193,74.71693558258161,536.8382936503541,133.19240272468892,138.88188539824858,97.89752277122935,103.03978931849778,180.3930977730261,27.857076637167467,232.01454107914813,,668.5081818227643,-46.916220092708826 +194,94.7179484548552,528.5879931877043,167.33593047751842,151.98648079624488,115.98030537041419,107.5054558811344,239.22142414051623,37.35401867304103,260.97146858293587,-11.936227074229762,,-51.787807936253074 +195,,,217.8085901014224,182.48023904543865,159.93215121974032,99.15290864514365,282.2694970567174,61.033522330183246,331.7998744204043,7.836058270171776,614.7248448636142,-39.94716574048603 +196,144.07998888099007,638.6030067193034,221.8756065050274,158.9504014306619,162.3768725494087,79.3027624634326,297.9295058431828,72.12569235043196,358.76401854132376,16.390313683196197,624.9011939922614,-18.32589276421993 +197,226.784699667261,631.9260760619696,298.0801487815389,242.37902834076687,207.34296285055248,103.36921007559394,304.43280419457943,153.67263356495513,434.00042144461486,,781.0752305940262,0.7581759621300819 +198,352.5873606375453,654.8088785473565,437.271047558539,370.1879030259056,260.2672047636163,133.16969204342791,,275.22733178267447,523.2483452198198,72.0169509593992,856.8862716841029,40.82760569105899 +199,198.25586359498914,618.0421325755888,294.62414085345324,232.26620445527828,221.7397327361993,,318.4966037160262,172.203889248805,415.4563571526851,100.21161387007456,724.6754297505875, +200,208.2785555284259,668.4261224611423,335.12036319081204,232.28174232099644,235.2897827893277,128.13285690175996,,165.51196058795483,420.7110111608114,96.67758826121916,,69.42075541055145 +201,158.85557383121665,557.4620783606135,247.30324178306785,165.361868946947,185.74419986673212,74.16588840165866,272.2783424537407,98.92217913834597,322.9798442844814,91.22376637969278,600.7785175024142,67.85229348545433 +202,129.48927218633358,510.0404662574122,185.31954664777322,107.96113949931684,162.21461200756437,55.16396561416326,260.8778388581564,63.39381270219613,232.48620744510245,63.25978260442335,429.0884347617864,69.7105354182253 +203,112.99988106056955,459.6995668995887,178.1974609069834,95.25087105002443,142.22399213644042,60.469754755579324,224.27804646427694,61.905684383701015,181.46920192128064,72.99834144132508,416.5979200389903,77.42157145327738 +204,119.24992859692381,,216.149243189814,125.96824066365994,,,222.8649084644274,78.4499038111159,208.6524549733934,111.00462222095355,567.1114377613943,104.15409314221795 +205,121.01917863381973,505.47016732599985,190.13147657533696,102.59552429531948,124.64788835468833,111.39388959223052,186.95492453161225,104.92645293464585,176.24011007061065,124.87864258621661,535.7225338185191, +206,135.12775160629695,529.290057342677,214.8468814098119,,,130.83676062531737,181.88944893971419,130.88309025806652,188.31569877744218,143.55901049404585,431.39447293184617, +207,161.9232225213641,585.5825240802948,255.54658701949288,166.13692334312967,129.93254735397068,,211.07180965394704,174.28832362809246,246.49569064366102,,,160.02741728942246 +208,166.336423544004,635.2475280386765,,186.7583836775905,123.19037807516837,152.0558985234784,200.40548414614432,218.14216326600956,260.6511073323346,203.28894147098762,450.7238240197819,176.71731627456805 +209,259.49475385718466,621.713860193631,323.1661544549112,,140.91102271420954,,175.44766380908345,,292.0833685868904,,576.9652848230996, +210,398.20612993712587,694.7623835288412,446.07274769865955,434.8129787233305,177.60521323053857,247.15104497556092,173.40704349315595,450.8141435907,403.4687875488457,268.289283215176,702.6051495943348,187.963174571891 +211,236.17502680701648,675.3511414815946,306.24034214276196,267.62503961506343,103.567718769359,208.6561822263807,165.55027755823187,353.239037135776,300.8637689471881,263.0248018230219,627.1200551119507,156.48811180751233 +212,203.66551691897558,658.8862998621668,287.9375970254629,,87.33161282071714,224.40245390842196,151.95706150396285,320.58948657921974,,238.72390339009615,602.7003626435766,138.716883769525 +213,216.8689539534285,586.0968653081873,258.1452882593784,,,199.33256028391963,121.1571035797696,275.20809871629694,224.17964973477837,203.76991000941945,616.6510624200093,79.30074349067309 +214,159.27178754734493,562.6691107345274,163.54359615640962,188.65509197250253,12.840909730196167,181.76355915761417,83.82566255979265,214.87466329287722,119.9699829062184,125.85289654218228,443.4728850129042,22.801829751269796 +215,132.51055763658894,499.7931761228439,,153.76211420745477,-13.230428628224189,172.07647577657252,46.61571090314868,188.6272430015881,63.50036525099796,90.06074648662565,465.68868316621445,-26.761819158055438 +216,138.6402721623516,536.2003653416672,162.9773124384937,215.12278122049253,-1.6116780476119175,227.5920040345899,61.4469011731951,196.57433930486428,,81.18943547476519,685.0436500384668,-36.688810162149 +217,120.29402653217507,559.3633125129068,126.1229230660785,193.57597527361008,,218.79372591044603,51.33503120398842,184.6900769561672,136.76913383679494,55.3566173116987,657.9627021371592,-58.242981363641846 +218,126.95906334804911,591.0292040687536,144.01645234659222,250.94003477556564,-25.14057005797588,236.1626767581144,74.47064879591395,,160.50692550483888,42.15914542035173,,-60.374428350239214 +219,156.48546038612676,,187.13852408629967,301.07022908000874,-11.011601702014246,241.10597476697666,,206.69098187204497,271.10091483112336,40.29363456496672,641.6501497366728,-58.32152781609621 +220,169.00465602945243,678.5766560960458,179.61732663625435,301.63521745337215,-9.941344906731501,235.7018035193555,111.1607214873195,221.59383161842175,327.83293101861045,47.03434324810968,700.7962062793703,-26.09477231259271 +221,269.9865230743759,677.2942358958311,256.06982883309183,400.2820080496053,13.677443330890497,267.6252603225005,124.04168390188559,277.4647474562394,429.3005480081048,67.22034642747812,834.0332047193995,1.7715165582131505 +222,398.41986343814796,746.3446912495265,390.6421649768834,545.272516877374,71.70412143966483,301.7891927870413,156.62332741366248,402.0251525048482,542.6916082301107,86.05285382402052,928.8299443096511,18.18920296590838 +223,231.19864523562418,720.3030879582548,231.54059205198013,393.4259778454049,16.716128664783383,263.08169501378586,167.48089948189042,277.13219981954387,488.9582124895877,70.86903751640897,855.8240844232978,38.043686783994204 +224,188.0453062337002,747.9033424173972,210.1094715934331,383.18168049595033,24.371778715111304,246.1284161102543,185.77955971178062,,504.6888581916684,47.43725989744149,,64.57995773791774 +225,171.42077742475638,,178.50096307055054,386.86213381972095,11.764757505115455,204.62042084805432,190.56113501558804,147.0412831420663,443.72113871920396,,757.6182167457355,46.57766708647655 +226,,624.7568898186083,80.29850251131239,,-26.96525144613547,167.1626901503684,180.07270112007063,74.18654871748052,346.6913497958282,-30.516962887728155,593.1142219524954,25.653610597835858 +227,80.34993567857327,561.5198607819198,43.93430996579039,271.4294609260751,-44.58660629450824,150.9613716205173,149.927055305916,38.132499330312314,,-43.444207282371096,535.9060969645138,33.45787820722922 +228,89.23309167208296,619.3021835249735,78.3516371639335,327.57977297616924,-23.54419118042118,173.9971749842396,197.1806564937132,35.7894413958395,360.3884611972513,-26.500465419354626,,61.66090031745341 +229,83.4481904240412,,,282.366699665663,-18.2079251235798,,196.15084847470843,34.473330112963524,321.85787776190426,-15.087612374004934,647.4698840094068,74.96308990689029 +230,,681.0884855838533,73.72993025356345,320.53437551559546,0.9068849608391787,138.41749139769746,,44.80756306697187,349.03810682188276,27.106420840059123,570.8891421272807,102.38570894798472 +231,128.82689669750675,,107.22121400971875,357.8814258613734,45.82664335449167,124.25882585350078,,73.97732768126221,405.5232543989042,65.07036438794324,559.5152074680386, +232,133.50678996402948,,97.97120633091902,355.4146936555148,60.093984392858125,,296.5300490162443,113.93653727623942,432.6413906181781,117.27959222469119,612.8993373615048,182.8803742373 +233,228.66911567307386,794.3906354241228,199.59239360442842,439.28062308625374,125.65566584996003,113.1778461761991,,206.2193095931756,525.5253794599726,159.3042694023382,717.3167641418001,192.7234996644282 +234,,872.344523268957,310.28924173407904,563.7382496183664,196.18220385843264,,340.83517635426404,345.76386748373375,611.2889436741219,227.99444476591464,826.0997386124474,212.82054362675638 +235,175.104754048908,,,396.2994296231538,152.5669349583241,104.98694471739574,343.0692239796379,253.80006550060799,489.6370840912301,236.1229333996475,704.1953320283629,200.16739435702965 +236,175.94263268201018,858.1891153425942,207.1818675895649,398.7347475307079,159.34754842444656,95.24890470191542,342.9701169868406,228.81582470227687,474.40459294887535,245.29846685144997,719.9863916441941,201.61279349467952 +237,104.34695927401977,806.4057563130343,129.41296068594,348.4997416938926,148.63232389314848,37.01140105001451,315.0321290068329,151.14767260105788,329.33159693411227,203.66627252074508,629.5301704953179, +238,51.77697992560924,740.0871100914271,32.76461795294378,246.06203050749144,126.9053695853805,-15.618805890899054,259.29768024500567,100.53479499327781,183.38789857859555,159.1375007833245,479.89115786502464,66.93863006976153 +239,44.826002876138666,698.4126060042809,43.53770625566693,236.81880287567688,130.5013485979574,-26.585713493776808,239.9090714293081,89.85397924594106,122.29601125330828,141.13609572669307,474.6040787960297,27.90926126189658 +240,42.90263236531844,758.9734027166019,66.4883997994798,248.66230175550962,151.63801676095542,7.735852206174684,255.37523891970218,,191.21434900789583,,,10.624050524789517 +241,29.339698679588096,757.1244328341069,53.085101874705686,208.61662920811125,143.73698610112564,,227.64275109655017,143.20329889004148,,129.56604431345943,630.7829983383556,-19.863026501136495 +242,41.3169122064054,758.5478393604371,85.35525196389577,238.2487378828995,163.0370106452741,18.088037786367863,226.887334440405,165.31154701092646,177.59161400691963,140.45063203617667,608.1087318947748,-22.537260941874393 +243,54.756399429744576,826.5235364181804,,255.29465805502872,188.1674269002716,21.00308911085716,248.16271933246318,212.51005426712595,251.55740862957782,155.81376332498846,,-21.248380430518722 +244,59.106511591270134,824.5559914363018,158.96821684845705,255.0046418964152,212.70249983630083,25.58218322755232,222.41444043775618,264.16426663196296,273.80589198266523,,758.244700568796,3.7352117944926704 +245,147.86420184574672,849.0258139730076,271.43168489907265,351.90243257160796,263.44198352704996,77.71785915055071,217.57695411528215,366.19145277304017,341.6618561290439,188.39968628279695,885.8436255675064, +246,264.09909338761145,875.7562759034819,389.6176755352584,474.20614349159246,326.61872445946557,122.75166231615002,224.46933832968358,524.120195863784,438.7934681331468,,1015.2230608529892, +247,,842.3800188405089,258.54024722812346,288.33623553351254,270.66331637192025,100.77051023375944,200.05308318553918,396.18183776131,382.4582298552771,153.57153918100803,,47.58657373596397 +248,98.21246091861448,872.3974762791375,,271.3964667561346,269.7662770888255,120.12453563596151,199.26149307665878,371.3804229141068,370.00305714553303,132.61155343971757,920.6035425252899,52.728232974692276 +249,82.88156313737832,819.5708078326323,240.68605917577025,257.757394541067,256.0611951518079,105.04226373087958,,313.24399723314957,327.73708390194037,65.55655320171854,,55.76465425015378 +250,18.464833290154047,754.7581719416304,136.47724941033766,149.34880538624805,202.25166229469127,82.955806985224,101.91374842290526,251.59595064260532,232.16302426640797,-2.035886397750488,698.0163086344905,25.42237304923833 +251,-0.5012540046056984,,150.484858751344,140.0187200538756,,105.1561580497487,99.69106889099214,233.67031438892036,212.19393458982705,-42.79927831788493,675.4679571840143,31.55196669302253 +252,-0.2465339809004803,726.5455724227736,172.98840801754085,155.5192850283011,185.28935590782584,155.08846394485803,102.89627412642443,226.9038171862859,301.0790798283164,,811.6015500625308, +253,-13.06840068583628,701.2692805581726,155.6455787187735,103.69446004764343,165.27635371150944,154.23908705475813,68.44277387265893,212.29753307551545,287.8160574001511,-55.49870121576359,759.6981080206264,83.53818531524715 +254,3.5170493878822526,707.909609156758,193.93133507471123,142.0653294065988,174.5723850697088,195.09093845138796,91.47024582894727,204.47815322659514,354.1004971046975,-37.620206838964265,715.0268910155435,115.009835208731 +255,52.3437783926762,743.6420945580057,266.8259489629464,189.23301258770041,186.13125287198088,218.85342817008785,123.90831647915832,230.12174981638773,455.8525484287541,-7.411208582165699,713.4788899019862,161.79365767391067 +256,49.24898841450583,797.8946638310397,287.10637186982564,197.80622238022647,177.61129715589303,218.7299854233213,132.3668991999308,234.83547907837664,492.0607421966711,30.078254475941222,746.3214954457011,209.93397521574894 +257,146.62592184663393,801.6445407544139,381.97669027099647,306.6991719629799,197.32794200220135,250.98728059336378,144.47440891527052,297.2735781373156,606.3384474098152,77.41460905844332,851.9906215193898,255.0713522503628 +258,258.31496915542306,847.680885506564,494.13798515346787,437.82567903042303,225.58448271930365,291.13743704178773,171.8077967760134,389.0922271244239,701.9686412729575,111.55011041113067,921.8924153674625,265.5547451543009 +259,104.65975510591943,801.4300265623181,359.4754394967165,251.21400967952513,143.02761118909848,254.68764756935644,170.07973246180057,273.410204719897,628.0219007654039,125.16652706762966,803.2652830588313,247.25335822626192 +260,102.39178199575983,837.6735911543468,385.98947879920763,295.1453343361537,141.10089341500742,275.0397452519934,204.08027646639235,256.16282223898327,628.7207827589283,148.97272800533875,791.0055024462525,233.08900580898447 +261,48.10939276502397,724.0126523124021,306.12688034074864,226.77314943018087,98.91388214221428,228.49051624511637,173.08087271462222,158.10935517333138,544.8843395610036,129.01952294398268,703.7655936354104,185.2491120977313 +262,2.2898391621129406,674.7638573263065,222.36967159528749,148.9937832196856,48.80696493648682,187.7201562355801,147.9706967911697,76.96143989988317,372.00932451217244,88.40377767715174,516.1043597988997,112.71155008412548 +263,-18.364350564207,592.3029794869409,225.75585720469672,139.03628376595233,32.916748195928974,177.995858324255,149.2024859699763,41.046981193778024,309.6769776916011,98.40214877944425,481.50769494334025,62.31312504439335 +264,-2.3349717275965816,636.4615913202055,261.10980468808714,161.72594583117825,26.843172164716286,221.594743940654,184.73334451174523,43.51071599446903,366.28128812852617,118.6326705685197,635.0799131887931,52.301165800912145 +265,0.27170324299328286,645.1248756198934,247.38226792820694,141.05940192562423,2.721027750382774,197.47194234570728,175.83872050628239,33.02568435317659,318.8655041082335,135.48718891242896,620.7465384455902,13.727127397470746 +266,11.231374769154613,653.0854294537426,274.8707982811009,194.81089055509773,-0.598099394823393,195.29177692337447,230.72484558573973,45.37065161615169,325.7943754464442,166.03357389480527,622.0085978806994,9.319028456576056 +267,65.32161621308136,701.338141232516,336.39059626312434,262.8311676823619,5.8517867463280595,199.3168678857049,260.91251320488686,100.77789736930444,408.7661059654421,213.95648471578085,670.5429773885444,22.546718285896574 +268,76.9559512355084,744.5728654340115,339.08759890397505,277.30053198589616,5.630906790604257,185.5333435402326,286.12081756670375,125.86410468218642,430.0350719033027,249.89813384185433,764.6480870706611,44.82812531217722 +269,181.5350090707335,711.4159848850728,451.0670109481933,373.9911691725731,40.25438011164394,198.76367764242354,304.76656368896784,226.9773942644058,509.7311219907351,269.66031338515154,888.9001223643752,49.88505663974843 +270,301.9680528630329,764.8238495541746,575.9763066663673,481.45923586352797,92.40934647150421,211.44998623332543,323.702490320841,351.332702546194,567.5614176809777,274.6063425683556,967.3264359006101,52.40459184607026 +271,141.94851605084534,706.2834625541216,401.4366893649723,330.1787014098727,22.894062494319613,162.60357719065394,320.39183161565023,273.0824752385042,472.7379108514644,259.90449248935766,892.5185295034155,57.319760375426725 +272,126.83922513491898,723.4427128581231,407.1983365999291,365.8791737260165,41.960408668364394,154.20628484055226,344.83711082736914,248.32224712184168,439.66502082340725,226.0154292633386,905.8658479486849,57.11174105529659 +273,111.55113242441139,629.973747491731,351.8849032950423,370.9258412182389,26.760960727851597,105.85417055700928,316.7106100185783,225.58522244026145,378.36488376591194,174.7982628294066,886.3926680591599,39.76825530310171 +274,39.38351275730332,574.2708010596591,235.5336098891857,264.90279119197123,-9.33501928137602,32.61968907311834,260.6031251161901,163.33337701918117,196.87475258635544,109.35076058535299,694.0486287032234,0.045935275628238514 +275,20.024545187634587,506.79773611687824,225.28962587544322,265.77365659982996,-11.709725276565536,12.653103426467254,239.35694271651846,161.5581617776043,149.14022292297037,74.75393743038724,668.5164930176677,7.366259119969385 +276,27.674119530329342,557.9892414475883,247.8740188081457,298.3159512560481,0.41573422184526976,40.766669338482245,255.33546106381448,188.42654612797563,204.95046813905466,52.75508873831161,817.9062761630694,30.59637710576454 +277,33.17682892018344,579.593084091524,225.45001717297941,263.01353189040134,17.382474333979985,25.575866292761418,234.08021976821084,206.15002634661658,214.53179963498266,31.8528972956241,806.0773019778435,57.53104892666769 +278,45.07309347723905,606.4284620504341,252.64522113169173,296.4040451615164,34.84978319155934,30.94681124428908,232.32078523427978,232.66710567090166,253.19306313479098,19.194399178823552,769.9628687316983,95.80971571939457 diff --git a/python/cuml/test/ts_datasets/endog_hourly_earnings_by_industry_missing_exog.csv b/python/cuml/test/ts_datasets/endog_hourly_earnings_by_industry_missing_exog.csv new file mode 100644 index 0000000000..458080ffab --- /dev/null +++ b/python/cuml/test/ts_datasets/endog_hourly_earnings_by_industry_missing_exog.csv @@ -0,0 +1,124 @@ +,Forestry and Mining,Manufacturing,"Electricity, Gas, Water and Waste Services",Construction,Wholesale Trade,Retail Trade,Accommodation and Food Services,"Transport, Postal and Warehousing",Information Media and Telecommunications,Financial and Insurance Services,"Rental, Hiring and Real Estate Services","Professional, Scientific, Technical, Administrative and Support Services",Public Administration and Safety,Health Care and Social Assistance +0,13.65,12.11,13.65,11.38,13.44,9.5,9.71,12.35,17.14,13.83,12.61,14.79,15.19,13.68 +1,13.851591792200573,,,11.78678800254696,13.867070571786073,,,13.0909155671951,17.812435117728164,14.82753845996605,13.295474886760168,,15.970541016832259,14.403721755556253 +2,13.932186081902552,12.355023621386243,14.73108527800528,12.198193625189077,,10.140409682929818,10.624890778487318,13.635226975795222,18.36749014305581,15.400909822132215,13.913703912697493,,16.5780963891776,14.681660977669893 +3,14.270795333322575,12.389220656198583,15.037142876939978,12.529790889900646,,10.411711219132206,10.758341117348921,14.053573457078269,18.776523723975927,15.931418721560089,14.526344870973595,16.746351574028473,16.802179355432582, +4,14.456451844523633,,15.44068996229278,,15.358240546466158,10.897833721372711,11.264642978217099,14.367647763213672,19.131335743321515,,14.648792834011681,16.795357183854268,,14.389451990191542 +5,,,15.835638324049468,13.025904835817773,,10.814465936104252,,14.627823362307268,19.163520039723746,,13.845282074885532,15.854124630510897,15.740833420283229,13.78885365892225 +6,14.755192726717508,13.391837153802191,16.10703989305711,13.393799246358679,15.743149821626593,10.871785353725251,11.25806724331454,14.793080078737047,19.099177482803615,15.671361692069091,,,,13.304967855532414 +7,,13.460685177131175,16.28128400492666,,15.764061922522005,10.624436687056946,,14.393242984219937,18.76028382617809,15.403333480182564,13.184251121142342,15.340403812554172,15.631810694537654,13.72564647530384 +8,15.111423095329485,13.860052744866836,,13.504369773165179,15.647866101549015,10.559756132924411,10.781764606923604,,,15.117207774393604,13.130494698085222,15.486256968345458,,14.787248466064424 +9,15.509152289611192,14.06848872408872,16.461348841513022,13.312529586519855,15.451333166356427,,10.40638565085023,13.745025043328631,18.41955609445053,15.23175800991669,,16.012411559233954,,15.168791033576886 +10,15.99905474034535,13.865012060228567,16.587632105308824,13.292312577028328,15.386623659307197,,10.190870942631053,13.751970213173598,18.531171516576446,15.72297198478351,13.998908635452404,16.72975951816983,, +11,16.510549491152634,,16.467664550151046,13.07488546826914,,9.558430839241527,10.054838707843363,13.913547294520743,19.01756504479315,16.352285753089898,14.740697910529839,17.439374575451623,, +12,,14.130868119227609,,12.923728527569113,,9.335576201739169,10.019043516690937,14.218635909450446,19.396699451288576,17.27975613593806,,17.718110781473023,17.76047886322311,15.907114431636783 +13,,14.150731910715734,16.28633645384955,12.593797428271577,,9.148404113495584,,14.647586412250444,20.094244516699277,17.895337095821482,15.599364719121331,17.639608655548603,17.341301647078755,14.80755974830567 +14,,,16.106656700032303,,14.560129339332676,,10.411476175242253,15.050114363658523,20.369120706454208,18.396085468372462,15.52625269520014,17.18359370446417,16.73204748507204,14.299466177489194 +15,16.223870619910233,13.70934880732469,15.81362202578458,,,9.172505474264355,10.752094658709545,,20.693143663634118,18.385372675770192,,16.827440101500745,,14.36343693548252 +16,,13.801490498350368,,11.893342197371068,,,11.216826705552865,,20.567511512606718,18.082536889458,,16.051596779045568,, +17,15.471557194866138,13.668462742553087,15.46862785139451,11.74453937774285,14.710541524792731,9.780577148513332,,15.95041834511198,20.519649957730785,,,16.153367914663377,16.409359490870646,15.488010058690753 +18,15.375498390845612,13.26291061275795,,,14.845142084689208,,12.00148005213932,15.896990406354737,19.93629090947301,17.3297862648617,,,, +19,15.539709223208575,13.177760452229116,15.38397799817525,11.70900403358093,15.168771394772293,10.352467812568,12.244499204115552,15.790451785461007,19.5920308828934,17.087343960951127,14.196487328272848,17.012128821795294,17.619713024493418,16.27321801625822 +20,15.69446493790059,,15.429866861658358,11.938756591185607,,,12.314705122501344,,,,,17.749446096990667,17.937073947586256,16.22194775081549 +21,15.240135253277526,13.161378607244325,,12.11813041571459,15.819371568103742,10.995682668360207,12.258713127760142,15.229731329617602,19.02350540163146,17.681672934656827,15.411103727888927,,17.85853675262889,15.381718553472455 +22,15.197180666001987,12.896791867830677,,12.340908626863227,16.255649396257507,11.16628018649039,,,,17.99707961824398,16.170386608740145,18.63097893045504,17.50223160782623,14.34404055935509 +23,15.666147849912862,12.685747756627947,,12.587553918226797,,11.067057786079852,11.634538821359905,,,,16.55284117684216,18.37411463966689,16.773365163491086, +24,15.997664193841038,12.771520533403976,15.883838619523702,12.975801540819567,16.540876626842746,,11.426089469830575,14.221864563946376,19.28565698097457,19.228503059556587,16.545635697312466,,16.25419287668816,15.129004646018375 +25,,12.737231365318934,,,16.976759104893066,10.78712389990577,11.175162483958083,14.190942829585476,19.92891758074776,20.44726321869338,16.541098174285437,17.405716978386923,,15.702683341424121 +26,15.861219136218843,12.575778232239756,16.40821550833602,,16.998852838617662,10.609594306346992,10.980124469555697,14.522717276261574,20.43841609404956,,,,16.759573747052517,16.60969775357549 +27,15.624856020184954,,16.6097391443485,13.973413242676422,,10.223372576700045,,,21.0150430628985,20.857495389596775,,17.17308743127811,17.501152586860183,17.291765990953927 +28,16.004222665660773,12.751472033735043,,,17.088009131580808,9.970208810437839,10.98630789707783,,21.204875571165708,,15.345740540642215,17.604157861823353,18.139016848716075,17.21497951345778 +29,16.360242220272596,12.752746935952635,17.329568388187358,14.695218134492185,17.067929818699444,9.781482314984832,11.241424375054338,15.957546333166437,21.402172850119477,20.231005212809258,15.68148060711822,18.701898715684482,18.49614296407842, +30,16.21387127467223,,17.7655885372505,14.979521275547857,16.868629438983746,9.603184175410224,11.566447395816509,16.402285419067706,21.11057521387587,20.02375158869258,16.07398515418231,19.459458459965642,18.536355526373136,16.06531535352705 +31,,12.729248489477106,18.252748100480247,15.10550412117673,16.852515742373786,9.361553170729106,12.276183482162534,17.11698636747496,21.22976443478051,19.846439731689753,16.8288311525825,20.080059424126155,,15.80283440155813 +32,16.19789387429227,12.975890026762595,18.74322017599655,15.249647631981372,16.79528490276848,9.539894201922701,12.692064826506279,17.2060588352331,20.808403215301606,19.88639489521436,17.50995215650415,20.08764173033968,17.744790863904786,16.111077495046775 +33,16.22028069913035,,,,16.56362615534087,9.736965011833718,13.121395859905862,16.952863243408935,20.627530768721506,20.87621265038233,18.24114337855649,,17.163482508801106,16.705274097802207 +34,,13.085743367995606,20.032451257981293,15.217544584566355,16.442837726647383,,13.497251503849032,16.709041697982173,20.40092819833503,21.555942829389064,18.668094954557187,19.595137806854765,16.993194760471493, +35,16.640759916691465,13.277375264259197,20.246157973409787,,16.414567019393985,10.392240569631083,13.677464388481305,,20.52970926387039,21.979620413485033,18.505790106315708,18.99814989302875,,18.271807600264626 +36,,13.478296977492253,20.383672706388467,15.039944043769497,16.244876914423898,10.84507896967941,13.881282153129611,15.931779204620836,20.667357601690828,22.35831805478115,,18.561786750174598,18.21484793792997, +37,17.055223754580084,13.74657331513443,20.479546406199255,15.051099087200612,16.412810921411136,11.293840421450092,13.652643787749549,15.711393559341623,21.428873809116848,22.847021849424287,17.643462873903157,18.96791975085005,19.04812786325131,18.00770732065041 +38,16.954960380099607,,,14.921442524650805,16.599585349075184,11.557217680347001,13.369557753021928,15.672036252105396,22.08424121637337,23.331504308037175,17.258136951830803,,19.314931298471475, +39,17.250810487412092,14.085552569295736,21.31104859145886,,16.948480655945556,11.609406542761098,13.08968006026147,,,23.27431880554524,17.365548563119532,20.56164781571207,19.633937754783584,16.973161317306587 +40,17.113532526841887,14.38065279629912,,14.541340620808118,17.33544168698889,11.714447269537395,,15.797014076583102,23.113328415685963,22.6351979321252,17.53771817259869,21.19111657080765,19.313838972090643,16.88528819722389 +41,17.273816286639146,,21.714902738460655,,17.460333178748254,11.739372323095271,12.371354413201876,16.044942075863236,,22.418687760841166,18.256232357579826,,18.5463289641368,17.017614842438775 +42,16.582276000843713,14.696385609707297,20.99960455771748,14.764434058310693,18.088738382459926,11.805778740677587,,17.03032407732933,23.223057961742313,,19.092169433861056,21.850813926939836,18.443025234761727, +43,15.849444212061595,14.860365522450357,,14.753515993654123,18.124140139501215,11.339634486279655,11.900095717096024,17.355666629438762,23.121050905424408,21.86247728032324,19.224628796820753,20.794627447501078,18.514207040432677,18.70024313132518 +44,15.27576645047528,15.320616487461226,22.301589778217586,14.642907698316534,18.590289580195478,11.14934226197248,11.906918629013393,17.585475193197365,22.50357017216377,22.42019843886321,19.62091128255854,20.168473970476658,18.450511457715322,19.071563663555693 +45,15.571596782253444,,22.619624538415895,15.074473950648414,18.653550857164607,10.742295550354001,12.49354322395374,17.363276600339248,21.95825654811762,,,,,18.536309231362477 +46,16.66719427183882,15.557018303060953,,15.336213816763681,,10.460344483734895,12.908732192447417,17.914854246360115,21.932220235526916,24.089636631165103,19.862668150212503,20.054846807713446,20.447042618530524,17.983305272619187 +47,16.03272039346193,15.817175190964967,22.273948414324124,15.592385058917381,18.72217654961037,10.21472132416064,13.372529517434845,17.825060814769373,,24.945329184343414,19.263388873745605,20.234109866176293,20.943438419699625, +48,15.878237417743936,16.15174981724681,22.225679675496103,,18.950075302536465,10.121040863029524,13.984147430927006,,22.188881435943383,,18.968166540505802,21.094017204097774,,17.358204932136985 +49,16.31370778950919,16.188500374352987,22.36436079531547,15.918698766348298,19.331714722254855,10.144981964722703,,17.012734162244552,,25.567491883235835,18.998233646681324,22.131303670355962,20.37452548297793,17.529168222938615 +50,16.328994503024127,,23.047652281526343,16.412880267270662,19.327806096393772,10.209173886834998,14.591312440422534,16.870327141732965,,25.151801583424604,19.09358904416378,,19.977983810232317,18.626350977009476 +51,16.96386247091328,16.561604858171208,23.132632172670487,16.771185342545433,,,14.776927433631474,16.43491938997211,,25.089869767886448,,23.022361478867225,19.47282313748951,19.189676970412055 +52,17.657980873074603,16.815632128222063,23.355513859026686,16.80813728803338,19.148372638785563,,,16.962486162240566,24.52905656492544,24.742324051956086,20.119032954691022,22.87524670263121,19.235492914443842, +53,17.710926462124604,16.94723426746753,22.77139837434785,17.12896528221741,,,14.625830443982585,,,24.279144933608002,20.468818172001505,21.93790165825381,19.92122724526191, +54,17.932187792346262,16.766482008417448,23.824071033742477,17.82052276434323,18.58129773033712,11.63760657149422,14.40468705652544,,,24.07247939265653,,21.746791242362942,21.212176895099244,19.466632935626034 +55,,17.02394811517415,,18.152104823892675,18.063484739491194,11.95510720877985,,17.633787064828706,25.454766883472814,,,21.976229288214732,,18.996601633753862 +56,18.407202393137414,16.900694520385542,,17.886113780071355,18.009253413403545,12.25533949944767,13.743415569336863,17.66040992522217,24.439442219249543,,21.824381402973884,,21.721663876203685,18.580126201927815 +57,18.499541863965845,,23.83817914168924,18.168531568801686,18.294305807028138,12.602880581672464,13.74624384020795,18.20365126888522,24.247790369878835,,,22.73596620166741,21.528403190557437,18.545476636547296 +58,,17.07856825710435,23.90551254023352,,18.357261796763886,,13.713296024851553,18.96843491960251,23.980921704976623,,21.41816837722393,23.533797431281393,21.50513998823391,19.296998224925463 +59,,,24.285649518273416,18.24168031770849,18.84888909256807,12.821410512218494,13.727997291703362,19.189288198415518,23.986533759970595,28.38561159794283,,24.15126644177114,21.46975971522927, +60,19.726087035025795,,23.703435152898418,18.05333043770335,,,,19.21337791343667,,28.700985722230676,20.338165948464855,24.091632566776166,20.612878227758525,21.005883256509247 +61,,,24.69261225651034,,,,,19.267047638632196,24.83045708710418,29.75143488758314,21.11997307119624,,21.187051145316737,21.41773919396587 +62,19.995929881514716,16.699060279979033,25.146070831147174,17.958278166068904,19.61914897599683,12.459004517783642,14.536940189116962,,25.532556260362355,,21.662963710049823,24.58750520343232,21.75210246530295,21.38334230745425 +63,20.497591560370303,16.734138198690765,25.406131521664935,17.779170374704705,19.898095690664995,12.051230405436923,14.839550731541058,19.129258148569917,26.11357794050113,28.093278515131207,21.63315913899587,23.256151004622385,,20.78545252821467 +64,,16.69978248997389,26.454851767393652,17.802675454050917,,11.670588179026606,15.52155059708487,18.779012816347073,26.574039745175032,28.251295337050703,,23.209282362887212,23.404269831445287,20.455057052783133 +65,21.059414399081263,16.72934178493705,,,21.056313602827252,11.398151397735806,,19.38867664197427,27.45949306660313,,23.587902255017724,23.64316988242037,23.519607296058776,20.666924131697026 +66,21.37759415006057,16.856076792146453,27.217082392695964,18.01756151493173,21.70599783573949,11.689148490020484,16.282929784889973,18.49321510028085,27.928544082288525,,,,23.845291869964207,21.25459275093894 +67,,16.883087069141187,27.8762182357895,18.00529742778539,22.23299901700501,11.60869114551357,,19.054766171232604,,28.384338539067027,,,23.575687095029227,22.26407886180237 +68,,16.983241413053133,28.555837628541866,17.876717452006172,22.20521772862452,11.608729407788429,16.275452439635796,20.16624148187153,27.7237340665492,29.44521210890627,23.4972973661789,,, +69,24.849767677596628,17.159113433608127,28.15120495097779,,22.578127918108695,,16.43932385719536,19.862737642450988,,30.652999349279003,23.336847344443964,26.71285841578283,,23.540006985792726 +70,,,,18.64125702500245,22.765181197267374,12.285362319359294,,,27.642487869180943,,23.491747387643667,27.65106758298586,,24.11351491428843 +71,24.133813704441792,17.196490147449122,,19.06526123383819,,12.685826767352129,16.037145783270336,,27.424661351589144,32.48035488878264,23.3168147760208,27.332019035177307,23.74313316948091, +72,24.21027090371882,17.52118670002999,28.715667461627902,19.644663870283804,,,,22.121161867346384,27.23491084348189,32.685301217174995,,26.75566112769298,24.37939195845465, +73,23.75879803605197,17.647912965565116,28.625634325067445,,22.294199070305364,,15.860994618294951,22.055020395192106,27.587982388194263,32.991220230046096,24.489477039785008,26.539507902546546,25.46257567120528,23.127909220661756 +74,23.67882788678878,17.657073756617397,29.029928688150573,20.60193858517074,22.46739059862349,14.063268508860913,15.618677707041277,22.681631186993123,28.10917072181593,31.75923178044594,25.23726046557517,26.23400813803925,26.22726838707944, +75,24.159878143305637,18.078570054672575,29.502797928780684,,,14.476102595067202,,,,31.745879207097587,25.603334526919085,25.9663258976724,26.81195494657682,24.541977453257402 +76,24.65155621411613,18.41180123388161,29.939742245325682,21.54306449988818,22.354332993435467,14.489940610178635,16.14634908245693,23.152153883330744,,30.859906236562804,25.93216350649512,26.415852312334756,26.46617142958987,25.814472415749652 +77,25.013563151515392,,29.5172953810659,22.392398660764044,22.304190746784766,14.88317339664858,16.877370427684426,22.83580350735958,30.879712185316443,,26.35204794583786,27.655905628652068,25.753691781736272, +78,,18.698649590352154,29.562764144467955,22.827495246529704,22.60401292160853,,17.179600275628168,23.247840360461893,31.618542557952846,32.77535782975501,26.293184753582853,,,27.050939374333044 +79,26.197853010585654,,29.793937109789145,23.104942296080896,23.171993323956755,14.744140776130486,17.94565588869854,22.552668013473806,31.7214022064877,34.00573324867393,25.853145405702424,29.468730333670663,25.60097934425267,26.96073171152589 +80,27.380028290923008,19.183563670995653,30.638774111200075,23.543668714122607,,14.47858561599292,18.342814184753543,22.994223793266446,31.855879490396024,34.82580040793662,25.77833840788696,30.10374383269659,,26.391829021531287 +81,27.81231823069846,19.760444473509402,30.385088909314987,23.355632618252535,,,18.702914332613528,22.83324849043029,,35.78064636746854,25.471904866097244,29.73637863009134,27.462944972445285,26.0764944965911 +82,,20.04640821235519,,,24.059233349837488,,19.101540068463787,24.03116568971664,30.84816281935988,36.54689818504657,26.483235316951447,29.448914770848795,,25.74921701018323 +83,,20.008283569740986,30.190825752875785,,,,19.755032047285805,,31.07723491431308,36.487960412213994,26.590236329204323,28.8278656081488,29.014646936902277,26.67331560890523 +84,25.502302454602887,20.562766465667032,30.452444265676892,23.744915275100784,23.707438743003053,13.137509210859115,19.264223133213886,24.320519876037537,30.63628309849687,36.0254555236954,26.594034950316818,28.257519994211492,28.518395317470052,27.452327325067984 +85,24.98794903469184,20.86649426043948,30.189447275227156,23.702779333149962,24.38083938962448,13.342681283865911,19.314306908601367,,29.899727722416877,,27.84098431187967,28.38623504666513,28.456597635889285,27.761532246733413 +86,24.845617395496213,21.08612127230849,30.14477816891343,,24.622166422847084,13.468559682627623,18.83086016132365,,,34.7520045993564,28.783803339705965,29.327774660718184,28.095856393360037, +87,24.975947903258024,21.43839391426747,30.689852322332705,23.40522989047162,25.02821476658986,13.77177566394805,18.59264731211536,25.6004725031517,31.12801786780056,,28.596999421154976,30.225696146703147,27.574605637577918,28.113191879854224 +88,,21.790223762602924,31.394500672733532,,24.890717225469206,14.144732766007824,18.582342804648768,25.189615494007423,32.04254488801376,34.05001498159384,,30.703074304091057,27.50334645443357,27.906752452872883 +89,25.54751628074122,22.078756916232816,,,25.367976238958263,14.557740459851596,,25.020401897240475,31.862737859584932,35.20417548622854,28.451159841496754,30.583205519360252,28.49564501814814,27.404560966229297 +90,26.560372285078834,22.441438089172223,,23.38738938385884,,15.032614261837429,17.917282088007717,,32.511765645690375,,29.09612815401034,31.620221493503493,29.23645536381405,26.808448117928055 +91,,22.556067996747686,31.271012751790312,23.315859477353488,25.068411126946085,15.417305776945842,18.08450806077097,24.892374975737827,33.73246257269199,36.95764327124107,28.256297557749672,31.418579060919292,30.128555708622592,27.488432436281695 +92,,,34.572113419973014,22.847111578729052,24.767010726891833,15.811001259018141,,25.116813865965792,33.97866962705299,38.214626749107296,28.431151306018794,31.53716422390082,30.38174281304598,28.36685077524534 +93,27.117983350005378,22.974444126582963,35.340538848156605,23.41097756623431,24.729836576882796,16.20907046417267,,24.626896136404437,32.61408744294142,,29.94675106935023,29.88041496850416,,29.37686520065353 +94,26.13006272750542,23.42596988081832,33.778810336863174,23.294018041509705,25.06999976730927,,19.575180659076135,24.625512889143184,,39.98674412640572,31.960126131890853,30.9445532093056,29.66697037433533,29.288583947087055 +95,28.151894557143997,23.25505330056562,34.27905665525281,23.299114423153114,,16.314662097420715,20.019266239129706,25.031999034138526,32.93860466764186,,,,,29.332919796443235 +96,,23.701821683344587,34.833327842258065,23.63528979406735,24.5470457509538,,20.646561737159086,25.85629253863722,33.340664934576346,39.31618581303483,31.08799702043318,,, +97,27.298801802302354,23.4669034742969,35.11391214644418,24.22776964143364,24.52538579905033,16.14614409900249,20.84204222104956,26.460937836209936,33.080820736420876,38.967613538827436,32.037502832611956,31.395424283185832,,28.088454804263375 +98,27.055841675015113,,,25.06827871279372,24.249381174013656,15.828090121758219,21.016757110709822,27.460534505046308,,38.652216044422666,32.11881214130925,33.19460011032051,30.097118050651403,27.667260704032692 +99,27.896543286045464,23.556931180523904,37.236187054490216,25.10555553116717,24.36399906409404,15.719154746661877,21.203979618222824,,,38.48639069824376,32.02861068270491,,31.309918439168708,28.23417965903108 +100,26.921816444485163,23.725458576402232,37.06629203555905,25.826052579145948,25.177004888380445,15.38141711511016,20.781847539869162,28.01590683292107,,39.23536793018039,,32.10322398031023,31.506223641014824,29.110517917882447 +101,28.082525282035398,23.899381909328028,37.72000077044671,26.25477861791052,24.858268635040684,15.175598590873083,20.882078477028287,28.445643913965725,,40.509223906146104,,31.596694745122885,31.378555962948795,29.514948435187144 +102,28.67947961888848,24.121414724100628,,,25.11977647489118,15.106125108464909,,28.69733233750169,37.377206246487134,41.151060024634106,31.98676954731085,32.240541897793385,31.460619844385484,30.304477227057085 +103,,23.55453760092264,39.23975388299498,26.995358412408784,25.67634661289073,15.41695287110958,20.61142160105469,28.52177590445198,37.93397576249284,42.340835704061114,,31.975214062536953,, +104,29.52504458850386,23.631931217994364,,27.278498952893074,25.72698543267635,15.418663551516804,20.203393124927736,28.727059951824913,37.14962382093216,,32.46801652545539,31.787805184535422,30.770603616736167,30.430609991411085 +105,29.104933673171587,23.96690492033853,,,,15.807181215662224,20.344049550443668,28.338977283447083,37.41371103499975,42.90190026324728,,31.887594908647223,30.51707674923643, +106,29.6336123870842,,,,27.219038966871167,,20.4594045758944,,,42.057978097779035,,33.16779409537264,30.838878982241567,29.06150063917811 +107,,23.943026150023712,39.55649611328729,27.803013447079607,27.207844281353324,16.339728454014644,20.74262717227577,,37.60923375450625,42.203898634977385,,34.25501856819074,31.378730345613043, +108,29.758968329738003,23.900762408746257,39.76126501927288,27.819852603987723,27.270198918778952,16.885084864367318,21.26196249232675,27.49708295657939,36.366245496535896,42.20013471137889,33.7719901280318,34.69039241904022,32.203329518494556,29.815075523280772 +109,29.24622668205186,23.689108246801936,39.71083467126169,28.287228937799178,27.798165631548986,17.38822886599891,21.892587702446274,27.761016261471127,36.39006828617381,41.921955305224294,33.97326285582644,,32.99367277343588, +110,28.953431104361282,,39.64283015402416,28.13148479382319,,17.668534355838588,22.362109742434207,28.201618684991946,,,33.40501431614065,34.67751042211091,33.795086413983356, +111,31.56062709826174,23.808681852470528,39.18517851055612,27.80226529139295,28.19341926763215,18.091782142069754,,28.91017851837316,37.31730370849863,40.350676565208616,32.87604085843692,33.75992786236827,33.90641010601702,31.06121007243653 +112,,24.24462419101302,39.08579271588977,27.811918637060838,28.381407148496546,18.434309177666876,22.781519242309443,,,40.750231326075166,32.37189877919565,,33.43757416356101, +113,30.754677187671135,23.35049937164877,38.542256211749944,27.96472837292261,28.54634640535743,18.475877813089294,23.599017322044045,28.994150825960794,38.78143388613222,41.207956265328946,32.85659088136453,32.98881088350147,32.71319444813519,30.9083153188646 +114,29.451126287462294,23.917629634237787,38.97152057866358,28.136025295151047,28.336692993176232,18.41091715811818,23.518377283249407,29.42791972093787,39.23925677537909,41.36156898366093,34.4992487277059,33.80780913929345,32.468713081836796,30.791005294663236 +115,30.316761013992284,23.966859875322662,40.02962839838097,27.941234259343467,28.515779467449228,18.257986032035472,23.620032883947097,30.170534098666323,39.35029132076366,42.19395930363636,35.93094272320709,34.3195832389059,33.04204108839906,30.699865611609432 +116,30.721142969265745,24.218538673573697,40.33147235520991,28.18491513770975,28.95973075651526,18.237522621892214,23.489315933078423,30.524783296117274,39.20198382646372,44.646579557148975,36.672734356542584,34.67759486186219,33.48047774628434,31.539809645187752 +117,30.613746355869168,24.022510461463455,39.23060018013212,28.07985755673791,27.92513049951384,18.178153713095085,23.68403573498857,29.892952307239767,38.680586480268516,45.45620837168165,37.691677857875035,35.05488204685573,34.12668393647759,32.030759425418246 +118,30.363963230564828,24.668119018978047,39.68907321396445,28.42628570016464,28.176646409073587,17.98201101498513,23.578509831621627,30.044375769470324,38.65186834376624,44.26129493191529,36.991714995045896,36.45477828185509,34.73090012617649,33.04598348088618 +119,30.11110957865866,24.734222210903923,40.087384228205536,28.621222618750576,28.434826778009704,17.819622498356807,23.36498845698211,29.99561515151064,38.21725453610195,44.947163941983646,36.75469741619022,37.34151110042108,35.311334273448985,33.907428220081655 +120,30.974432148449008,25.42921763672246,41.25443777979108,28.218053570999604,28.834264961524852,17.94499819361194,23.273826593944108,29.85521492121272,37.8263388737728,45.19559657022285,36.06371296657798,37.479652816971694,35.13198987142161,34.373956127550805 +121,30.27311597746858,25.481078617071045,42.21759388295207,28.666315544181327,28.872295796251493,17.961506319001167,24.02395524954119,29.976072828888345,37.860051482131524,45.63186048672116,36.024060713808964,36.80817746477709,35.26781460544899,34.190745522952575 +122,31.8062925353201,25.97739970859231,41.662773257116314,29.301726042220057,29.488340401140576,18.349039777267706,23.965156324851055,29.708934484805052,37.81105088311454,44.098913515609055,36.172013940114326,36.2715477099431,35.20593550433007,33.72041309158106 diff --git a/python/cuml/test/ts_datasets/exog_deaths_by_region_exog.csv b/python/cuml/test/ts_datasets/exog_deaths_by_region_exog.csv new file mode 100644 index 0000000000..35824ca89b --- /dev/null +++ b/python/cuml/test/ts_datasets/exog_deaths_by_region_exog.csv @@ -0,0 +1,29 @@ +,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31 +0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0,0.0,0.0 +1,0.19875209056106533,-0.03703668679562561,0.295948341698736,0.037036756290946694,0.38958056246995176,-0.03703686180403085,0.47942543957854117,0.037036891815146945,0.5654604994420314,0.03703385723138967,0.650319402879635,-0.03703686888160362,0.7176549202255535,-0.03703588287852368,0.7937041483849424,-0.03703588151587406,0.8414708110015084,-0.03703672098440032,0.8912071759819662,-0.037036644135165414,0.9322228269534192,0.03703667738402852,0.9649541407882584,0.037036840568423934,0.9907454886548319,-0.03703685981409009,0.9975729368463327,-0.037036427545358824,0.9999899958342958,-0.03703650200042182,0.9918603059772687,0.037036115422779625 +2,0.38958056246995176,-0.07407337359125123,0.5654604994420314,0.07407351258189339,0.7176549202255535,-0.0740737236080617,0.8414708110015084,0.07407378363029389,0.9333893780282453,0.07406771446277934,0.9947834290640049,-0.07407373776320723,0.9999899958342958,-0.07407176575704735,0.9867488209877175,-0.07407176303174812,0.9092972390096973,-0.07407344196880064,0.8084962368241235,-0.07407328827033083,0.6755963404934754,0.07407335476805704,0.5162482046746085,0.07407368113684787,0.3367883600959299,-0.07407371962818018,0.14113103602388338,-0.07407285509071765,-0.05839846036885846,-0.07407300400084364,-0.2555914790708911,0.07407223084555925 +3,0.5648776867457179,-0.11111006038687685,0.7844617548535878,0.11111026887284009,0.9324273458780673,-0.11111058541209257,0.9974947805708058,0.11111067544544083,0.9752584930881769,0.11110157169416901,0.8713852646397194,-0.11111060664481084,0.6757445585300422,-0.11110764863557103,0.4330416582256388,-0.11110764454762218,0.14111997891143627,-0.11111016295320096,-0.15774566156077083,-0.11110993240549623,-0.44260768120567157,0.11111003215208558,-0.6887625607282502,0.1111105217052718,-0.8762595779997068,-0.11111057944227028,-0.977606507732746,-0.11110928263607647,-0.996579581542493,-0.11110950600126546,-0.9259971962430006,0.11110834626833888 +4,0.7176549202255535,-0.14814674718250245,0.9333893780282453,0.14814702516378678,0.9999899958342958,-0.1481474472161234,0.9092972390096973,0.14814756726058778,0.676441758149388,0.14813542892555867,0.338160994485027,-0.14814747552641447,-0.05839846036885846,-0.1481435315140947,-0.44838279812868687,-0.14814352606349623,-0.7568023389898726,-0.14814688393760128,-0.9516018773354787,-0.14814657654066166,-0.9963609914451044,0.14814670953611409,-0.8847345610651648,0.14814736227369574,-0.634659033919092,-0.14814743925636037,-0.2794373334021834,-0.1481457101814353,0.11659775580137664,-0.14814600800168728,0.49421076000889724,0.1481444616911185 +5,0.8418215167270572,-0.1851834339781281,0.9989401079316864,0.1851837814547335,0.9096762132340875,-0.1851843090201543,0.5984720204891397,0.18518445907573472,0.1413244557374412,0.18516928615694836,-0.354105675287268,-0.1851843444080181,-0.7571177568390984,-0.18517941439261837,-0.9904800920614084,-0.1851794075793703,-0.9589240765966958,-0.18518360492200162,-0.7055401798405714,-0.18518322067582707,-0.27947058181072887,0.1851833869201426,0.21543164327100997,0.18518420284211964,0.6605172125775963,-0.18518429907045045,0.9380732777200355,-0.18518213772679412,0.989770384000888,-0.18518251000210909,0.7986445253454938,0.18518057711389813 +6,0.9324273458780673,-0.2222201207737537,0.9752584930881769,0.22222053774568018,0.6757445585300422,-0.22222117082418513,0.14111997891143627,0.22222135089088166,-0.44316154499371474,0.22220314338833802,-0.8798309129181567,-0.2222212132896217,-0.996579581542493,-0.22221529727114206,-0.783001799069347,-0.22221528909524435,-0.27941544048547,-0.22222032590640192,0.31154129916430334,-0.22221986481099246,0.793824326510654,0.22222006430417116,0.9999899855124953,0.2222210434105436,0.8591914808355817,-0.22222115888454055,0.41215069065251386,-0.22221856527215295,-0.17439940060295403,-0.2222190120025309,-0.7000126599889558,0.22221669253667775 +7,0.9858602391935614,-0.2592568075693793,0.8644599416218897,0.2592572940366269,0.3351276963093557,-0.259258032628216,-0.3507831552351123,0.2592582427060286,-0.8728384682208625,0.2592370006197277,-0.9917579244663942,-0.2592580821712253,-0.6315296048789927,-0.25925118014966575,0.01703664511066828,-0.2592511706111184,0.6569864630177729,-0.25925704689080226,0.9881680297701986,-0.25925650894615787,0.8547673826174453,0.2592567416881997,0.3195606562508116,0.25925788397896754,-0.3684485701738951,-0.2592580186986307,-0.8797645046756954,-0.25925499281751174,-0.9795856256264135,-0.25925551400295277,-0.6182589709920092,0.2592528079594574 +8,0.9999899958342958,-0.2962934943650049,0.676441758149388,0.29629405032757355,-0.05839846036885846,-0.2962948944322468,-0.7568023389898726,0.29629513452117556,-0.9976078027780596,0.29627085785111734,-0.6372456874930467,-0.29629495105282894,0.11659775580137664,-0.2962870630281894,0.8041820959227095,-0.29628705212699247,0.9893580422707821,-0.29629376787520256,0.5849170720767289,-0.2962931530813233,-0.17436114778033118,0.29629341907222817,-0.8290257830693025,0.2962947245473915,-0.9844397823918536,-0.29629487851272074,-0.5366148490271767,-0.2962914203628706,0.2316062652459557,-0.29629201600337457,0.85933118891123,0.296288923382237 +9,0.9742533069920303,-0.33333018116063057,0.4279990470348874,0.3333308066185202,-0.44270478422051296,-0.3333317562362777,-0.9775299157556051,0.3333320263363225,-0.7738840304681833,0.33330471508250703,0.016973153546815833,-0.33333181993443256,0.793998482402275,-0.3333229459067131,0.9827385691500352,-0.3333229336428665,0.41211840011841105,-0.3333304888596029,-0.4575357992709801,-0.33332979721648875,-0.9811296105727584,0.3333300964562567,-0.763087507992051,0.3333315651158154,0.033803735832937575,-0.33333173832681084,0.803847239088657,-0.33332784790822945,0.9660600410120782,-0.33332851800379637,0.39681878605673,0.3333250388050166 +10,0.9096762132340875,-0.3703668679562562,0.1413244557374412,0.370367562909467,-0.7571177568390984,-0.3703686180403086,-0.9589240765966958,0.37036891815146944,-0.27982030153241716,0.3703385723138967,0.6632092552607707,-0.3703686888160362,0.989770384000888,-0.37035882878523674,0.41757808565271376,-0.3703588151587406,-0.5440209985214498,-0.37036720984400323,-0.9999900000020655,-0.37036644135165414,-0.5366786972951043,0.3703667738402852,0.4207757540174443,0.3703684056842393,0.9959308311896113,-0.3703685981409009,0.6503386575500577,-0.37036427545358824,-0.28802324866978724,-0.37036502000421817,-0.9615870205441172,0.37036115422779625 +11,0.8088331995037931,-0.4074035547518818,-0.157974228291299,0.4074043192004137,-0.9519984831623965,-0.4074054798443394,-0.7055401798405714,0.40740580996661635,0.31199270920924493,0.40737242954528635,0.9975276812951132,-0.40740555769763975,0.5851608520908025,-0.40739471166376046,-0.46359716800381917,-0.40739469667461464,-0.9999900000020655,-0.4074039308284035,-0.4496473716596205,-0.4074030854868196,0.5921902353181169,0.40740345122431376,0.9882015506186382,0.4074052462526633,0.30474729997418254,-0.40740545795499106,-0.7118409655458136,-0.40740070299894704,-0.9492397584664318,-0.40740152200464,-0.14902839929043718,0.4073972696505759 +12,0.6757445585300422,-0.4444402415475074,-0.44316154499371474,0.44444107549136036,-0.996579581542493,-0.44444234164837026,-0.27941544048547,0.4444427017817633,0.7948176905375149,0.44440628677667604,0.8626932520163858,-0.4444424265792434,-0.17439940060295403,-0.4444305945422841,-0.9939313274391613,-0.4444305781904887,-0.5365728071709441,-0.44444065181280384,0.5920733924140472,-0.4444397296219849,0.965848145080876,0.4444401286083423,0.10790976045352942,0.4444420868210872,-0.8923367754232626,-0.4444423177690811,-0.751045933420351,-0.4444371305443059,0.3434579436630076,-0.4444380240050618,0.9999899980286128,0.4444333850733555 +13,0.515716114442965,-0.48147692834313305,-0.6887625607282502,0.481477831782307,-0.8838226767946192,-0.48147920345240114,0.2151199436546398,0.4814795935969103,0.9999899855124953,0.48144014400806573,0.3221207063604493,-0.48147929546084706,-0.8281713171032081,-0.4814664774208078,-0.7720780738215345,-0.48146645970636276,0.4201669500408618,-0.4814773727972042,0.986771760456212,-0.48147637375715036,0.10777489471581637,0.4814768059923708,-0.9304700815816802,0.48147892738951115,-0.6080831644211951,-0.4814791775831712,0.6055871902376866,-0.48147355808966474,0.9291821426953464,-0.4814745260054837,-0.10865801190272363,0.4814695004961351 +14,0.3351276963093557,-0.5185136151387586,-0.8728384682208625,0.5185145880732538,-0.6315296048789927,-0.518516065256432,0.6569864630177729,0.5185164854120572,0.855837008655436,0.5184740012394554,-0.3699502407716862,-0.5185161643424506,-0.9795856256264135,-0.5185023602993315,0.03406847349780193,-0.5185023412222368,0.9906071510842663,-0.5185140937816045,0.30311829413640534,-0.5185130178923157,-0.8877420074045105,0.5185134833763994,-0.6057090742401774,0.5185157679579351,0.6856284592184476,-0.5185160373972614,0.8367210198267926,-0.5185099856350235,-0.39772129305819093,-0.5185110280059055,-0.971990024839086,0.5185056159189148 +15,0.14117879448049286,-0.5555503019343843,-0.9789463148795308,0.5555513443642005,-0.2795318944295686,-0.5555529270604628,0.9379997830302238,0.5555533772272041,0.412715542089693,0.555507858470845,-0.888027809036927,-0.5555530332240542,-0.5367964384047176,-0.5555382431778552,0.8144326792815318,-0.5555382227381108,0.6502877058397336,-0.5555508147660048,-0.7117851953493901,-0.5555496620274812,-0.7511352954312417,0.5555501607604278,0.6064171458886196,0.555552608526359,0.841151785030102,-0.5555528972113514,-0.4872125831994987,-0.5555464131803823,-0.9059555991626467,-0.5555475300063273,0.35912913832409527,0.5555417313416944 +16,-0.05839846036885846,-0.5925869887300098,-0.9976078027780596,0.5925881006551471,0.11659775580137664,-0.5925897888644936,0.9893580422707821,0.5925902690423511,-0.1745793372286799,0.5925417157022347,-0.9884520228948253,-0.5925899021056579,0.2316062652459557,-0.5925741260563788,0.9784504703557679,-0.5925741042539849,-0.287903257198445,-0.5925875357504051,-0.9488443019336659,-0.5925863061626466,0.3433826094832344,0.5925868381444563,0.930140826606319,0.592589449094783,-0.39969212792443704,-0.5925897570254415,-0.9056491293324435,-0.5925828407257412,0.45062823450164763,-0.5925840320067491,0.8794464004107753,0.592577846764474 +17,-0.2556475528903014,-0.6296236755256354,-0.9271559567810723,0.6296248569460938,0.49431918414341114,-0.6296266506685245,0.7984869476954495,0.629627160857498,-0.7008886313740063,0.6295755729336244,-0.6239918053963832,-0.6296267709872615,0.8595197162521094,-0.6296100089349025,0.40199645238299897,-0.6296099857698589,-0.9613972933022695,-0.6296242567348055,-0.14899899503835154,-0.629622950297812,0.9999899980286128,0.629623515528485,-0.10879398274242924,0.6296262896632069,-0.9770208430761707,-0.6296266168395317,0.3590864129959616,-0.6296192682711,0.8796393407968044,-0.629620534007171,-0.5857527917659857,0.6296139621872536 +18,-0.44270478422051296,-0.6666603623212611,-0.7738840304681833,0.6666616132370404,0.793998482402275,-0.6666635124725554,0.41211840011841105,0.666664052672645,-0.9823573618878574,0.6666094301650141,0.03394150832089122,-0.6666636398688651,0.9660600410120782,-0.6666458918134262,-0.4786804663343352,-0.666645867285733,-0.750987091654764,-0.6666609777192058,0.8136735694422572,-0.6666595944329775,0.38132565088632636,0.6666601929125134,-0.988345352496325,0.6666631302316308,0.06756924542115483,-0.6666634766536217,0.9564506653568983,-0.6666556958164589,-0.5019983315990214,-0.6666570360075927,-0.7285043559362132,0.6666500776100333 +19,-0.6121127728395851,-0.7036970491168867,-0.5514833485346234,0.7036983695279873,0.9683228787330317,-0.7037003742765862,-0.07515110493929571,0.703700944487792,-0.9206604032955211,0.7036432873964037,0.6759116003241548,-0.7037005087504686,0.4866013081585165,-0.7036817746919498,-0.9971015513627467,-0.703681748801607,0.14987717870571562,-0.7036976987036061,0.8871573454493772,-0.7036962385681428,-0.723637384869562,0.7036968702965419,-0.4199684653962168,0.7036999708000546,0.9999899462604889,-0.7037003364677118,-0.22377312579484251,-0.7036921233618176,-0.8503231178393341,-0.7036935380080146,0.773480342431362,0.7036861930328129 +20,-0.7571177568390984,-0.7407337359125123,-0.27982030153241716,0.740735125818934,0.989770384000888,-0.7407372360806171,-0.5440209985214498,0.7407378363029389,-0.5373502782659456,0.7406771446277934,0.9999899052848786,-0.7407373776320724,-0.28802324866978724,-0.7407176575704735,-0.7609360610761183,-0.7407176303174812,0.9129450621581827,-0.7407344196880065,-0.008851307462160092,-0.7407328827033083,-0.9057568865597412,0.7407335476805704,0.7636632073908185,0.7407368113684786,0.27236162276856596,-0.7407371962818018,-0.9881088348112222,-0.7407285509071765,0.5516563892843228,-0.7407300400084363,0.5291869888000177,0.7407223084555925 +21,-0.87193884509668,-0.7777704227081379,0.016838259630527815,0.7777718821098806,0.8549549087119184,-0.777774097884648,-0.8796955782699294,0.7777747281180858,0.033671758626498904,0.7777110018591831,0.8537573325168359,-0.777774246513676,-0.8879367677509318,-0.7777535404489972,0.051090669799684324,-0.7777535118333552,0.836655465724281,-0.7777711406724068,-0.8951871829181393,-0.7777695268384737,0.06722132183841885,0.777770225064599,0.8285264922776944,0.7777736519369026,-0.9074048925451954,-0.7777740560958919,0.08398101795948178,-0.7777649784525352,0.8181069117562207,-0.7777665420088582,-0.9098460023507845,0.7777584238783721 +22,-0.9519984831623965,-0.8148071095037636,0.31199270920924493,0.8148086384008274,0.5851608520908025,-0.8148109596886788,-0.9999900000020655,0.8148116199332327,0.5929312814881291,0.8147448590905727,0.3059893459398148,-0.8148111153952795,-0.9492397584664318,-0.8147894233275209,0.8244530003423381,-0.8147893933492293,-0.008851307462160092,-0.814807861656807,-0.8032555607809535,-0.8148061709736392,0.9544732210283877,0.8148069024486275,-0.3204034750533767,0.8148104925053266,-0.5808196568477796,-0.8148109159099821,0.9999899992185399,-0.8148014059978941,-0.5994330513136535,-0.81480304400928,-0.2947296926806361,0.8147945393011518 +23,-0.9941049459093658,-0.8518437962993892,0.5792777792669651,0.851845394691774,0.2229827634447428,-0.8518478214927097,-0.8754519938632033,0.8518485117483797,0.9450628489858791,0.8517787163219624,-0.38569021124812053,-0.8518479842768831,-0.4347466492543621,-0.8518253062060446,0.9738857369672237,-0.8518252748651033,-0.846220229387787,-0.8518445826412073,0.1664799691506045,-0.8518428151088044,0.6245002243215082,0.851843579832656,-0.999941600805745,0.8518473330737505,0.7099643773157669,-0.8518477757240721,0.05749197052133409,-0.8518378335432529,-0.7831005942562982,-0.8518395460097018,0.9857945987654071,0.8518306547239314 +24,-0.996579581542493,-0.8888804830950148,0.7948176905375149,0.8888821509827207,-0.17439940060295403,-0.8888846832967405,-0.5365728071709441,0.8888854035635266,0.9670567737039758,0.8888125735533521,-0.8959736355102713,-0.8888848531584868,0.3434579436630076,-0.8888611890845682,0.3863011637685579,-0.8888611563809774,-0.9055781749588148,-0.8888813036256077,0.95428489738448,-0.8888794592439698,-0.5018882231177466,0.8888802572166846,-0.21456293876368854,0.8888841736421744,0.8221608903943978,-0.8888846355381622,-0.9918563569924574,-0.8888742610886118,0.6451653778459284,-0.8888760480101237,0.040701279566260545,0.888866770146711 +25,-0.9593237341482497,-0.9259171698906403,0.9393589046790989,0.9259189072736673,-0.5442477339904711,-0.9259215451007714,-0.06632188365236895,0.9259222953786735,0.6512299449692022,0.9258464307847417,-0.984866659017702,-0.9259217220400904,0.9133255567115579,-0.9258970719630919,-0.4936284286610337,-0.9258970378968514,-0.13235172276043156,-0.925918024610008,0.699239887226607,-0.9259161033791353,-0.9882264033759778,0.9259169346007131,0.885150931234674,0.9259210142105984,-0.43048370222645677,-0.9259214953522524,-0.1978142568212713,-0.9259106886339706,0.7454235525794333,-0.9259125500105455,-0.9962828703311972,0.9259028855694906 +26,-0.8838226767946192,-0.9629538566862661,0.9999899855124953,0.962955663564614,-0.8281713171032081,-0.9629584069048023,0.4201669500408618,0.9629591871938206,0.10790976045352942,0.9628802880161315,-0.610561503823061,-0.9629585909216941,0.9291821426953464,-0.9629329548416156,-0.9999898675235251,-0.9629329194127255,0.7625582929726507,-0.9629547455944084,-0.3199398958003875,-0.9629527475143007,-0.2142947777660121,0.9629536119847416,0.6881166132861658,0.9629578547790223,-0.9684970602594969,-0.9629583551663424,0.9638707030374314,-0.9629471161793295,-0.6886974011408288,-0.9629490520109674,0.2160298456399789,0.9629390009922703 +27,-0.7730863983809267,-0.9999905434818916,0.9712949391614468,0.9999924198555608,-0.9813448590810474,-0.999995268708833,0.8037842605294165,0.9999960790089675,-0.4731064079718626,0.9999141452475211,0.05090026690622852,-0.9999954598032976,0.3814093093312873,-0.9999688377201392,-0.749578910983535,-0.9999688009285995,0.956375730864423,-0.9999914665788087,-0.989486878875324,-0.9999893916494662,0.8329236544412376,0.9999902893687701,-0.5170101552123412,0.9999946953474461,0.10125834574727577,-0.9999952149804325,0.33417728942596997,-0.9999835437246882,-0.7052042823325289,-0.9999855540113892,0.9406143579027518,0.9999751164150499 diff --git a/python/cuml/test/ts_datasets/exog_guest_nights_by_region_missing_exog.csv b/python/cuml/test/ts_datasets/exog_guest_nights_by_region_missing_exog.csv new file mode 100644 index 0000000000..33411f7bd4 --- /dev/null +++ b/python/cuml/test/ts_datasets/exog_guest_nights_by_region_missing_exog.csv @@ -0,0 +1,280 @@ +,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23 +0,0.0,-0.02746317167259738,0.0,-0.007057893525316508,0.0,-0.014159119369999468,0.0,-0.08749370035088656,0.0,-0.04865625779594764,0.0,-0.008568659351486464,0.0,-0.010099001044268438,0.0,0.06541195138945713,0.0,0.01367535330005259,0.0,0.00702884733073994,0.0,0.027210792607960475,0.0,0.2436440609625011 +1,0.06661711827247965,-0.05839362741669856,0.0832366491561994,0.030356773267918424,0.0998330429860223,-0.022676823066760606,0.11640181278434966,-0.08353910219117414,0.1329382701053118,0.1146733584393578,0.14943757315012937,0.018127975473701558,0.16589609842745917,-0.025668754114287165,0.18230801406500982,0.14878472563092088,0.19866928975977965,-0.025979754882546776,0.2152732472099785,-0.05469365918347875,0.23122030493460827,0.04385994346356848,0.24740390815307164,0.17328459408252173 +2,0.1329382701053118,0.08641101988419007,0.1658956004585934,0.045504541181782765,0.19866858720694477,-0.05871959016947563,0.23122106381895619,-0.19602046063379353,0.2635166923672206,0.17518177469758947,0.2955191005755992,-0.08125760752546007,0.3271946292138714,0.08238020291612126,0.3585055968790339,0.11044374803339405,0.38941826187403455,-0.09740021468319436,0.4204800842020212,-0.18706058979347182,0.4499089604913546,0.05455515643726392,0.47942543957854117,0.14369220782091904 +3,0.19866880397813208,0.07569944878855102,0.24740316552292724,0.0415514746130159,0.2955191005755992,-0.15920769477754748,0.3428967073565496,-0.19944875822861768,0.3894173096773208,0.2942784947111525,0.4349639061035205,-0.12424093300579578,0.47942543957854117,0.08178194243293531,0.5226871209692666,0.19454246319638996,0.5646423567677593,-0.09448483074511826,0.6060248170036572,-0.31646084048319356,0.6442135060693922,0.045456356318880335,0.6816386192303977,0.22999701028090566 +4,0.2635166923672206,0.08148909254286604,0.32719364707661835,0.036762583842814375,0.3894168847769294,-0.2149331081811167,0.44991043713013074,-0.137780459510568,0.5084052027756657,0.32500192197293243,0.5646403600268997,-0.038875860023334034,0.6183696753450455,0.16624018851145048,0.6693497019958143,0.1577922134993358,0.717355942729148,-0.07889291499163552,0.7632311729318405,-0.2196775313919245,0.8036030480328988,0.030138855532104488,0.8414708110015084,0.3724825361763016 +5,0.3271938291649748,0.15808485600062427,0.40471326513915407,0.06466726981444942,0.47942374418967615,-0.29012992984859837,0.5508073287584893,-0.009338350034207608,0.6183681633215211,0.34639528123348473,0.6816362087564646,-0.03435725501354441,0.7401767003120189,0.2353379841232961,0.7935776393349223,0.14218909558628462,0.8414708110015084,-0.10099295151332627,0.8847480125015222,-0.25229027875152843,0.9194390117768044,0.05439873691292068,0.9489844233421858,0.5197881426170552 +6,0.3894173096773208,0.1093040070860904,0.4794240004913979,0.0634119947267259,0.5646403600268997,-0.2951048719526334,0.6442156204319115,-0.0628476816067032,0.7173541886666447,0.5918235567629165,0.7833239777578337,-0.028717210622548818,0.8414708110015084,0.3248372017099977,0.8912071759819662,0.13854592261038334,0.9320388934539371,-0.09384739552397232,0.9648930766705357,-0.4549627585367087,0.985443334120646,0.08225754342511524,0.9974947805708058,0.5278870832139627 +7,0.44991068751229063,0.02436124338507627,0.5508073287584893,0.055083929897366816,0.6442152760320222,-0.3122591591151877,0.7288653629462123,-0.14320409379407573,0.8036061327442741,0.46362121986640603,0.8674199789650587,-0.03615306301164129,0.919444789341787,0.3084245513757743,0.9589660555042986,0.10596350259458488,0.9854495264431672,-0.09164386217785683,0.9999186954272832,-0.29610089850965643,0.99803872099732,0.07043982387897442,0.9839857436309893,0.7177817042201629 +8,0.5084052027756657,0.09721738025908783,0.6183678191902066,0.05733556464541125,0.7173534059482909,-0.2813354916901234,0.8036056855252041,-0.22092658378419633,0.8755928979700968,0.4126099613603387,0.9320355974912465,-0.08000409489538378,0.9719377006088968,0.32738203406528965,0.994583198503197,0.13196944522768964,0.9995733965789168,-0.10498310190874317,0.9881870328998575,-0.3752186596435092,0.9565425291927647,0.07410549162758007,0.9092972390096973,0.5377613391231191 +9,0.5646409761170574,-0.03890281288913384,0.681636573161724,0.027793399697221152,0.7833239777578337,-0.3533061173597477,0.8674204426910147,-0.16752851421805906,0.932036614453545,0.3194380099002905,0.9757197058473077,-0.1250112729241256,0.9974947805708058,0.31223221766233356,0.9968648225507883,0.10772011087940857,0.9738474297293235,-0.05435060081226761,0.9302466743354737,-0.4930855459100964,0.8632037641581153,0.06256476784212436,0.7780730361763881,0.655913272795306 +10,0.6183681633215211,-0.12378155753976353,0.7401744785300999,0.04217621001610438,0.8414678353141013,-0.2963939906157213,0.9194420294524152,-0.12779305942202387,0.9719353240488453,0.3439375153596155,0.9974912531369146,-0.10012387803902585,0.9954077521495933,0.34811192772181937,0.9657344542843688,0.17756403514764052,0.9092972390096973,-0.002597716402951101,0.8288069736600956,-0.40461895998540864,0.7230811886915083,0.11498446177100384,0.5984720204891397,0.4769077297661318 +11,0.6693480653170835,-0.16333808314677709,0.7935752572598723,0.0393910973677675,0.8912040244121957,-0.2990933885011083,0.9589631769842881,-0.20534744714928563,0.9945807665708869,0.2611814223410574,0.9968612973446134,-0.11290087347168024,0.9657344542843688,0.31625472585004094,0.9022354925697829,0.16310994483492183,0.8084962368241235,0.02235773865485984,0.6886113611541567,-0.22622988015017736,0.5437691487400569,0.10814643831108015,0.3816609132200015,0.7250295622631547 +12,0.7173541886666447,-0.1548719887252969,0.8414682851659445,0.07834239391218195,0.9320355974912465,-0.25240088438385544,0.985446568427954,-0.16204751289288136,0.9995709524446874,0.23253894232359504,0.9738439859193954,-0.03716586812829837,0.9092972390096973,0.36606272705327847,0.8084962368241235,0.11180584295279966,0.6754630410337846,-0.0018077423554732563,0.5162155355146849,-0.3390232473992746,0.3349859759811476,0.1280249518989127,0.14111997891143627,0.69061499472004 +13,0.7621732518327039,-0.10749146273124935,0.883521164203536,0.07749851922729012,0.9635545789701722,-0.227134619825772,0.9985321440790174,-0.02301125778510732,0.9868172986056173,0.2133432880418572,0.9289562380534917,0.11573990086914566,0.8276601811721678,0.31581808850216747,0.6876585526460597,0.1473546008632466,0.5155012653443151,-0.01897728433277572,0.319680912283269,-0.47067635828495935,0.10804727488884226,0.12553141079504324,-0.10819511218233183,0.5001783422199888 +14,0.8036061327442741,-0.09745270856158064,0.9194420294524152,0.16168783920818547,0.9854460416048247,-0.16499994847260543,0.9980419966437395,-0.014414062591243804,0.9565462009777684,0.17954378707502774,0.863206135791728,0.09821480307985395,0.7230857323844578,0.256515722817357,0.5437725656731565,0.1087667383838919,0.3349880809638813,-0.03137028783600982,0.10819766332341596,-0.5038099290319056,-0.12474735899242541,0.1296516865263856,-0.3507831552351123,0.6527590458580811 +15,0.8414687534561404,-0.08630108281320674,0.9489815747839533,0.16658207045861523,0.9974912531369146,-0.20543791189325872,0.9839827900096002,-0.17383700811973316,0.9092950156166836,0.28688678287081637,0.7780702846820935,-0.07700252789943487,0.5984720204891397,0.1838460879138721,0.3816609132200015,0.17896140060483928,0.14111997891143627,-0.02367782774153826,-0.10834502557034269,-0.33980911606950104,-0.3507809510000156,0.11781909169268824,-0.5715612006859756,0.710230762929154 +16,0.8755928979700968,-0.1683787958918241,0.9719347831520129,0.16015407573830234,0.9995698617944301,-0.23081665871083204,0.9565456686452177,-0.12726126841451088,0.8459025198570183,0.5187780000960701,0.6754606523985058,-0.04180270321099697,0.45727253218584873,0.22183060178803446,0.2067571035850584,0.14430496836230883,-0.05837413137036218,-0.024166478393756886,-0.3198213837119108,-0.37062366080128767,-0.5578029605792876,0.12137324462968649,-0.7568023389898726,0.7895574004897215 +17,0.9058269595847469,-0.16902460892909563,0.9881423494987066,0.23070555133215268,0.9916610988067627,-0.21253662258257405,0.9161036588614004,-0.18924771160813464,0.7674940227936468,0.3123417814877628,0.5576816300688126,-0.10739618137720164,0.30340040470028906,0.20313360490397186,0.02492340074937531,0.2373013518183809,-0.2555410492446477,-0.020538898905008443,-0.516342547528883,-0.40619563714367873,-0.7345932345595172,0.09895087372797683,-0.8949891733679401,0.753577791336926 +18,0.932036614453545,-0.11264313051297044,0.9974917863994587,0.2629510830189659,0.9738439859193954,-0.2627699299344032,0.8632065972649476,-0.14043201977455427,0.6754613894068556,0.261788679779552,0.4273782806180316,0.02639621636541322,0.14111997891143627,0.1501033399592531,-0.15774566156077083,0.24785074084622782,-0.4425203518919624,-0.024926683783861956,-0.688718974534551,-0.5470724378011553,-0.8715701156223173,0.1381452244458354,-0.9775299157556051,0.6636691642532521 +19,0.9541054183585924,-0.045204647021192194,0.9999182047737185,0.2719164604831875,0.9462965458347441,-0.1634547351306061,0.7985736553563189,-0.15351085934038766,0.5714383329324836,0.004189561320598864,0.28747693635747484,-0.05182901399217172,-0.025071380616657785,0.16129599131725675,-0.3351275508135408,0.331527427016503,-0.6118577645630672,-0.023931670874707624,-0.828890156288455,-0.5077790532169782,-0.9613097477752932,0.11990477950704977,-0.999292582570792,0.5646195757955813 +20,0.9719353240488453,-0.0809405340641556,0.9954047642428613,0.2647380191178602,0.909294023468162,-0.19819335130674062,0.7230835619043475,-0.20089707601163576,0.4572714140734101,0.09835143027300718,0.1411194798698312,-0.05056944898913816,-0.1905679235135467,0.2373306583850281,-0.501276945049239,0.37944749868119915,-0.7568023389898726,-0.05207385767435483,-0.9303015365019949,-0.34785399568795367,-0.9989484335773511,0.1262748821328722,-0.9589240765966958,0.666644711458042 +21,0.9854471168443051,-0.04408092769410418,0.9839827900096002,0.3300936140228323,0.863206135791728,-0.1740109647779066,0.6377626560314573,-0.1449486429206976,0.33498726185861305,0.19856648052622228,-0.008407215900141561,-0.12332105335558877,-0.3507831552351123,0.2964439956478614,-0.6506250026781153,0.440761213846951,-0.8715755923890364,-0.026087242791070525,-0.9882110091902089,-0.4152052834986769,-0.9824462362087195,0.09916979794690495,-0.8589343160131047,0.7140841900347635 +22,0.9945807665708869,-0.17520621354857277,0.9657315554477149,0.3073280918105685,0.8084933777444749,-0.09550176881645107,0.5437709334358083,-0.16130388679541083,0.20675659802727167,0.22916682322577894,-0.15774510372587924,-0.11325535588457732,-0.501276945049239,0.31160508293487893,-0.7781660137576312,0.5181695734129138,-0.9516018773354787,-0.02496576161388539,-0.9999106646841573,-0.38372656559318274,-0.9126975397134226,0.069321410563614,-0.7055401798405714,0.6571622130888845 +23,0.9992956942623932,-0.20237368465181715,0.9407777319110218,0.3240928737356791,0.745702421119178,-0.020819083768218345,0.44238627546146486,-0.33101684243892365,0.07485570355747272,0.2902472387428093,-0.30354037660019967,-0.03915293432881033,-0.6378785777625828,0.34073297855812773,-0.8796251767444788,0.48855952261140023,-0.9936907983859309,-0.10588932533583054,-0.9648534144201797,-0.505032208657691,-0.7934825752839127,0.04385463260925596,-0.508278972513879,0.5630873754943327 +24,0.9995709524446874,-0.21888182883006488,0.9092945095801384,0.2883045209049815,0.6754606523985058,-0.021852110054143656,0.33498707543309236,-0.20684100657225868,-0.0583739886352583,0.4203392932138093,-0.44251878701030367,-0.0567156707567736,-0.7568023389898726,0.36923479836852974,-0.9516018773354787,0.39499119090957796,-0.9961644030773823,-0.12032059306790124,-0.8846785733937238,-0.41042857191804605,-0.6312625407602636,0.06350738598296563,-0.27941544048547,0.4131988741508426 +25,0.9954053182011033,-0.19397991955498628,0.8715003954487482,0.26384178212475035,0.5984699041166742,-0.04887413849083492,0.22303349846197512,-0.21791232901451,-0.19056745754117643,0.47338528314241196,-0.571559179478056,-0.03452577726577143,-0.854752430689124,0.4420907663062728,-0.99168366710981,0.3788586812915607,-0.9589240765966958,-0.13226479990829104,-0.7631352040144949,-0.4402535455789761,-0.4348294174318781,0.06829607344714897,-0.03317920969436771,0.6114566398864039 +26,0.9868172986056173,-0.26617327000070784,0.8276576967918504,0.3019887498111387,0.5154994423807571,0.02051252259408932,0.1080476295090205,-0.25126019611708644,-0.31937807662254736,0.36060542916869276,-0.6877635849832111,0.007072230899562329,-0.9290143093822026,0.3893191043272919,-0.9985271217419511,0.32596222601754904,-0.8834544732420101,-0.1323212106755028,-0.6059068058872905,-0.41620942988372434,-0.21482946332720174,0.07912305668215713,0.2151199436546398,0.8510177964114687 +27,0.9738450484996455,-0.28172645738770385,0.7780707006413696,0.26825479102086597,0.4273782806180316,0.003805571022830546,-0.008407220394670538,-0.3105331496366426,-0.44251926985122086,0.4263853202209891,-0.7885223030999569,0.06840982487466114,-0.9775299157556051,0.4345358976629172,-0.9719028686545995,0.2377992247376011,-0.7727643279409716,-0.07898493599868872,-0.4203455492142622,-0.3501661521338785,0.01681379135703285,0.04775432836021079,0.45004398082371677,0.912952339487122 +28,0.9565462009777684,-0.27972446603143003,0.7230835619043475,0.2794845026334388,0.33498689634784823,-0.13453057247897232,-0.12474776842367415,-0.3608610723959134,-0.5578051017622284,0.31236710229467113,-0.8715725102422991,0.008289645460680358,-0.9989547107631299,0.4360771302910287,-0.9127032749163861,0.1961983048028075,-0.6312665074837803,-0.07449129381359164,-0.2151284793525788,-0.4388719550421398,0.2475457742649481,0.054276530604369384,0.6569864630177729,0.8329160520844267 +29,0.934997611335499,-0.4054742189383086,0.6630779147810708,0.23771498030668,0.23924843374130247,0.0036556374372045397,-0.2393922857642307,-0.442323157252805,-0.6631890809908589,0.28740539118269726,-0.9350490778034435,-0.044342777893656135,-0.9926949375583151,0.4813314717631015,-0.8229125377085427,0.18962050640635372,-0.46460208344987014,-0.11848956077252819,0.00014823122521962477,-0.45191781301930845,0.4648613020352531,0.0675544974471554,0.823080709003606,0.7725790027830468 +30,0.9092950156166836,-0.4866980728665873,0.598470224060917,0.24882736619302914,0.1411194798698312,-0.03285598097178348,-0.3507821022923873,-0.4352079345227285,-0.7568004884739346,0.4059779977431539,-0.9775264589233272,-0.12983629940721736,-0.9589240765966958,0.3854707214186822,-0.7055401798405714,0.23701475562867858,-0.27941544048547,-0.10000894457557132,0.21541801035036656,-0.46921270842502427,0.6569823346763795,0.05746263314673422,0.9379997830302238,0.7534117280349918 +31,0.8795526052775403,-0.5118802406245294,0.5297088946790524,0.3298756452246315,0.04158050680339868,-0.07970708038017216,-0.4574027977535819,-0.6212081347141821,-0.8369775854368209,0.44960574189980224,-0.9980507031977979,-0.10746362627203652,-0.8985780378788037,0.40648842505730326,-0.5645201793506656,0.33080812329563913,-0.08308938565532534,-0.11120979473953269,0.4206146099763294,-0.35101415793754276,0.8134963203155482,0.062141636981296274,0.994598573676141,0.7848618163058243 +32,0.8459025198570183,-0.503875277800938,0.45727115959489356,0.3304051297180357,-0.05837392494232528,-0.11250380088297358,-0.5578047913351813,-0.6712641792122447,-0.9022971116760534,0.4557221897227917,-0.9961608803480958,-0.29783659587158684,-0.8133292235738566,0.41081514979839323,-0.40457911407154984,0.34925274277647944,0.11654918077717834,-0.1535473624878327,0.6061428148409641,-0.5082727519552203,0.9259205333923181,0.05544743878647134,0.9893580422707821,0.537503657279199 +33,0.8084942599074373,-0.5718707456920367,0.3816597675917808,0.34289784848716953,-0.15774510372587924,-0.0587276181862525,-0.6506230497027428,-0.6831305202539154,-0.9515995505001217,0.6031215802319095,-0.9718994317212093,-0.30968985455403764,-0.7055401798405714,0.3705955218316874,-0.23107774057012234,0.47130585841099726,0.31154129916430334,-0.18824706178276376,0.7633271251254613,-0.4283163282445873,0.9881618203652709,0.04719771911842727,0.9226040196748309,0.6851307073815349 +34,0.7674940227936468,-0.6570170423410816,0.30339949398595134,0.28177756492025946,-0.25554014557660326,-0.11992769323395744,-0.7345956455558522,-0.754918527233537,-0.9840097118389171,0.565922120086364,-0.9258112171487184,-0.3343666831947572,-0.5781981223170799,0.30490096742081396,-0.04983131727135926,0.514049021717953,0.4941132490791689,-0.1835900460139457,0.8848174322229501,-0.20483100072709196,0.9968468356911052,0.07667635942183935,0.7984869476954495,0.5933998351597874 +35,0.7230839643116026,-0.5970227670000505,0.22303349846197512,0.20773962247789202,-0.35078191476306736,-0.07115556207879387,-0.8085809143873596,-0.6653991044461588,-0.998952268141697,0.6600930182805015,-0.8589312785696256,-0.4456846749801294,-0.4348321498101907,0.23355873290347698,0.13308530599033094,0.5879046966935789,0.6569864630177729,-0.21249674461609228,0.9649327177784034,-0.4052416724537488,0.9515048699875486,0.09412632624342786,0.6247238247170469,0.5955038742638565 +36,0.6754613894068556,-0.5874007583826399,0.14111955531283146,0.1984309501338832,-0.44251878701030367,-0.09239052406983225,-0.8715729761882137,-0.7439469909747064,-0.9961619672787466,0.5868257897774732,-0.7727615952198251,-0.4000373223836527,-0.27941544048547,0.2908924926809972,0.31154129916430334,0.5525185181834922,0.7936676999165313,-0.20899180115099678,0.9999267042604486,-0.3533431748471735,0.8545933614820114,0.13726763976756845,0.41211840011841105,0.3259218243065448 +37,0.6248378755884137,-0.641973589033756,0.05822618224837239,0.15265504605667682,-0.529834157814996,-0.10654507746151437,-0.9227154110697628,-0.7944490955728652,-0.9756883411528704,0.45438117266138056,-0.6692373524165389,-0.34910955372547253,-0.11625514104764849,0.2154553173251945,0.4795553394699432,0.4686390936839719,0.8987079101828758,-0.2125481999767755,0.9881630349566065,-0.2863188338367578,0.711364708079664,0.1516748759610729,0.17388944946344542,0.09198360175948733 +38,0.5714383329324836,-0.7567854515571073,-0.025071305360114156,0.19249117543385022,-0.6118556008546875,-0.1374556969245217,-0.9613129028741557,-0.7102147799338546,-0.9378948264360029,0.5211912634973803,-0.5506834814681028,-0.3846597488494382,0.0501269995284057,0.233356830350503,0.631496087836649,0.5642277889877553,0.9679194721070383,-0.20313336809244087,0.9301917917856273,-0.30002358958127845,0.52958159855392,0.16929565344098835,-0.07515110493929571,0.14068339267663424 +39,0.515500004852297,-0.7454765099093049,-0.10819478741401113,0.27649590443690614,-0.6877635849832111,-0.10322373659505071,-0.9868406924700296,-0.8163235756917707,-0.883452313039142,0.31022470103545724,-0.4197624467290389,-0.5015428042448056,0.2151199436546398,0.1476638485572891,0.7622709349138133,0.44902404966719955,0.9985431391248171,-0.1653392729075555,0.8287237728711316,-0.3674563381664206,0.31909629130413975,0.1938063426842304,-0.31951912762537305,0.12600304941546492 +40,0.4572714140734101,-0.8038221537316022,-0.19056735148747966,0.28907656980120305,-0.7567996627149973,-0.1120335910183939,-0.9989517122095232,-0.8412669726308702,-0.8133272348396607,0.16671750102521368,-0.279414452390581,-0.4728183709223925,0.374151153290036,0.08204501595154011,0.867496690436754,0.6424191247461458,0.9893580422707821,-0.19865662476453644,0.6885037326850879,-0.3639223267114612,0.09131664287817737,0.17752479330468524,-0.5440209985214498,0.26972662657026014 +41,0.39701125849728375,-0.8749078094902892,-0.27161729684580566,0.2356039762795395,-0.8182740483816434,-0.16868715913709328,-0.9974813045488092,-0.863353526743569,-0.728764414072919,0.22850150327964341,-0.13279141183324866,-0.48377933399166423,0.5228133111422222,0.13187538928241327,0.9436464949404428,0.7269314848359838,0.940730362371255,-0.15130648817056905,0.5160885121893056,-0.35600327945060084,-0.1414121736203554,0.17678174739765315,-0.734698278652349,0.4408443177930563 +42,0.33498726185861305,-0.8169439520381667,-0.3507821022923873,0.14639555889422715,-0.8715725102422991,-0.2548831662329098,-0.9824494606793018,-0.9463245050196543,-0.6312649639252141,0.32835981179907586,0.016813837552559784,-0.3758232426917584,0.6569864630177729,0.0976800310457925,0.9881680297701986,0.7474350132141053,0.8545987315703112,-0.16644296499199246,0.3195404338498622,-0.3751700489450812,-0.36647675069115965,0.16465222964561166,-0.8796955782699294,0.13757903845252992 +43,0.2714749842826905,-0.8080333618750497,-0.42751233030580743,0.14740883584282755,-0.9161625076841808,-0.13782391428429158,-0.9540605487338251,-0.9999845285483123,-0.522559641589201,0.18175937979285645,0.1660414843954216,-0.5216809293294343,0.772952196324663,0.16751228842831362,0.9995690633356505,0.6020398052628806,0.7343969461839074,-0.15962455021469577,0.10805029870568993,-0.378447686595194,-0.5716790663308193,0.17880963468253838,-0.9699976675669778,0.36154973153002856 +44,0.20675659802727167,-0.8220823849154185,-0.5012754403706232,0.13297748318130478,-0.9515985121923437,-0.0808684049336967,-0.912700535262959,-0.9269805064194142,-0.4045781248038636,0.08372887781588584,0.3115401974629625,-0.5656101908220632,0.867496690436754,0.13875940878369386,0.9774674663478088,0.6692735191346982,0.5849170720767289,-0.157752449835368,-0.10849238544324119,-0.23314518141875534,-0.7458975909148143,0.1598202578634686,-0.9999900000020655,0.5194601110964426 +45,0.14111963384806014,-0.9353943783477018,-0.5715594850357251,0.11464645577439447,-0.9775264589233272,-0.06356612301941582,-0.8589317377574907,-0.8756233225692146,-0.2794147572651636,0.0991667701525068,0.45004238933630103,-0.6461607434263958,0.9379997830302238,0.1678984280049077,0.9226040196748309,0.721759764022878,0.41211840011841105,-0.23127035608467914,-0.31996184813270867,-0.29943401784026696,-0.8796900504792157,0.18010852347007422,-0.9678077976098799,0.6114857677100735 +46,0.07485570355747272,-0.8491012008193805,-0.6378766630480424,0.12220632742832672,-0.9936872844040355,-0.08987481934800527,-0.7934851795599813,-0.766908416393876,-0.14929137097622078,0.15237122391952754,0.5784375993789701,-0.6117991310087655,0.9825075840487404,0.1793407514530167,0.8368175855346323,0.8065341132813423,0.22288986806218858,-0.2050192454234533,-0.5164695482291165,-0.5497977656103833,-0.965805177912614,0.2083546025546788,-0.8754519938632033,0.5260905450315775 +47,0.00825920445208658,-0.9131232777065567,-0.6997667049013352,0.1119181648525825,-0.9999195150083239,-0.16855399941805094,-0.7172506511056325,-0.6890468403636137,-0.016517845494420102,0.20637745713656913,0.6938423479770602,-0.5324137632326444,0.9997866249075904,0.13277955860697319,0.7229834742133202,0.7634180682540304,0.024775420335977266,-0.23869510222548215,-0.688826572823913,-0.5699829507821744,-0.9995757162306776,0.17497996963546517,-0.7286648253209377,0.57118989504468 +48,-0.0583739886352583,-0.7656640341842935,-0.7568000673029196,0.08237671474894245,-0.9961608803480958,-0.17849496304311105,-0.6312646126166332,-0.5787484152767336,0.11654889579376129,0.24254626136297167,0.7936648932749352,-0.4746627991978376,0.9893580422707821,0.16599498765525916,0.5849170720767289,0.6575433639089356,-0.17432674521566777,-0.25279639610980154,-0.8289733207543877,-0.6048477471997559,-0.979171373990692,0.1843099772803257,-0.5365728071709441,0.7022343562033321 +49,-0.12474783784774864,-0.6523912637263265,-0.8085809143873596,0.11289031949066372,-0.9824489354584368,-0.09422067788346132,-0.5366961027210728,-0.5230271726105937,0.24754672449437742,0.22885376178056327,0.8756634361098425,-0.5559903758498261,0.9515108490477658,0.0878741419329197,0.4272459609821542,0.7658523118797833,-0.36647905355542204,-0.2815461084215164,-0.9303563782839893,-0.40926665882610663,-0.905698023199917,0.22555259532572558,-0.3111192907192189,0.7834427676018797 +50,-0.19056745754117643,-0.5079685781093141,-0.8547498649862385,0.20540138952201537,-0.9589206855601181,-0.15812036703627605,-0.4348308445783347,-0.47668282983117616,0.3741502384241103,0.06603701772245002,0.9379964659880807,-0.5871835792546191,0.88729392482351,0.15617929072057038,0.255254815273456,0.6429011797594453,-0.5440209985214498,-0.28464213620892975,-0.9882349638271357,-0.48978539468110077,-0.7831377634018479,0.2616275804165923,-0.06632188365236895,0.6870332826712016 +51,-0.255540424401462,-0.5361299856209372,-0.8949864868867765,0.18584982467761785,-0.9258112171487184,-0.16292832397479162,-0.32705376559898175,-0.5448814701734853,0.4941120408846437,0.06728394139525266,0.9792641174406228,-0.5362334822119769,0.7984869476954495,0.09426598510640612,0.07470827495848803,0.6663392009191567,-0.699874543033963,-0.2936110117135072,-0.9999026120312465,-0.6774045581131184,-0.6181331003395811,0.2872425238055302,0.18259909691516205,0.4485153982970888 +52,-0.31937807662254736,-0.5837206065186612,-0.9290115207680862,0.16796268978505094,-0.8834513490880687,-0.23302537835877535,-0.21483016841543884,-0.5302503604278089,0.6053026352261208,-0.00728647100318081,0.9985396079836227,-0.46815864645736405,0.6875510730989038,0.10402492769864827,-0.10834226815830302,0.7271120776213198,-0.8278262980975494,-0.2892939797126546,-0.9648137310282042,-0.6967073712994301,-0.4196269362543102,0.3154831910833612,0.4201669500408618,0.3806559172429574 +53,-0.38179679637265995,-0.4690412138964364,-0.9565888183799907,0.08968142782436918,-0.8322643271775211,-0.08131980340344482,-0.09968580909690142,-0.5341209668481497,0.7057482265130745,-0.09469785320628528,0.9953900516549766,-0.3532490025955926,0.557560726485169,0.11859388455322865,-0.28776149542454066,0.7443057254586575,-0.9227752310129332,-0.28029118932997565,-0.8846091149010771,-0.7283693753040252,-0.19837788358002445,0.27551833230971645,0.6316108572585717,0.20448297267031548 +54,-0.44251926985122086,-0.5433102151684175,-0.977526981512618,0.09669995568007064,-0.7727615952198251,-0.09581167064225259,0.016813846541300047,-0.5009801804741575,0.7936657592582477,-0.21224328831597916,0.9698861806997854,-0.46656852911409236,0.41211840011841105,0.12429039573954802,-0.4575357992709801,0.6226110276352584,-0.9809360274534649,-0.19516560115450038,-0.7630392183755283,-0.561590188820998,0.03362282899735623,0.27781381119351617,0.8037842605294165,0.550661136721583 +55,-0.5012757193382068,-0.6879107881674135,-0.9916806903816313,0.07749476779922461,-0.7055376848437148,-0.07648551692613569,0.1330849065093335,-0.6115855993325966,0.867494569253495,-0.22830337562144812,0.9226007570766193,-0.44038436629766664,0.255254815273456,0.10829849131601803,-0.6119748417431544,0.6694395020494619,-0.9999900000020655,-0.20664153590721213,-0.6057887814944514,-0.5784987804587369,0.2638012556634479,0.29363231381560684,0.9259822515463416,0.6253838386849155 +56,-0.5578051017622284,-0.737157203994413,-0.9989517122095232,0.07015202763956879,-0.6312642751403974,-0.08733306588028837,0.24754658673084953,-0.5118915959373198,0.9259240876317014,-0.1431742304848978,0.8545957094587294,-0.3940183651788089,0.09131721669311393,0.1545065494414948,-0.7459022779813261,0.7529089307136472,-0.9791775269015103,-0.20976449155106006,-0.4202110050159997,-0.47490539140551946,0.4796822146294197,0.2651286078509319,0.9906071510842663,0.619054219511175 +57,-0.6118562684623196,-0.7595854393989256,-0.9992895830029984,0.08248116894654765,-0.5506834814681028,-0.12808440657889103,0.3586427028767836,-0.49823055072918093,0.967917105372268,-0.13082531556713486,0.7673982846054301,-0.31303080377335313,-0.07515110493929571,0.1464576413375254,-0.8548292521556993,0.8010841156336324,-0.9193283357767612,-0.1488037194499353,-0.21498370678133918,-0.4325140927773877,0.669565416485643,0.26455147785133115,0.9936408958955364,0.2937997370244784 +58,-0.6631890809908589,-0.8016739474932879,-0.9926919577946147,0.10292765215750962,-0.46460044048073684,-0.10853460481389234,0.46486282774886273,-0.46959222073148593,0.9927281852713374,-0.15759295293669015,0.6629667486924334,-0.34067671576111597,-0.2395367236538371,0.08824047786416422,-0.9351048507807446,0.8564878506137047,-0.8228284250129188,-0.19882595977122058,0.00029646244719124513,-0.26201860140331085,0.823159595034361,0.2807426517251061,0.9348948624214914,0.6207361817318204 +59,-0.7115754780003251,-0.892101863563142,-0.9792046269178578,0.07480715224332651,-0.37387526546857375,-0.1448246642811204,0.5647628265638364,-0.5738305893618623,0.9999168945374642,-0.17281667045546822,0.5436464088753101,-0.2598833493393143,-0.39728393254468647,0.09548814630460115,-0.9840384706537116,0.9035027171857681,-0.6935249415290577,-0.22515978877574455,0.21556276877057062,-0.2922873605552811,0.932140272057291,0.26229757364983214,0.8180215944917583,0.49439333443520733 +60,-0.7568004884739346,-0.9223558574338026,-0.9589211982026931,0.1390456376724821,-0.279414452390581,-0.15344837536581027,0.6569844909472605,-0.5723420883706045,0.9893556231149057,-0.25353585439795173,0.41211694274696403,-0.26644788218508997,-0.5440209985214498,0.1844627175706323,-0.9999900000020655,0.9999859918073343,-0.5365728071709441,-0.24113082325426516,0.42074912653423857,-0.059486813387637213,0.9906009263525001,0.2670904167817131,0.6502877058397336,0.5703964095999233 +61,-0.7986631867985713,-0.9999849682235472,-0.9319824472972914,0.20374538720638635,-0.18216182246643972,-0.18739900291242914,0.7402740046993728,-0.5710917948522494,0.9612318489479014,-0.07953656249269177,0.27133221855593354,-0.16241566092230295,-0.6756813184953772,0.27645788006987887,-0.9824247902316056,0.8931276414540231,-0.35822920824433374,-0.277264690711573,0.6062607993966257,-0.13053433186011273,0.9953731146371415,0.28317207149668555,0.4421220772559133,0.484002250842023 +62,-0.8369775854368209,-0.9598910011075292,-0.898575340624977,0.22251967041904414,-0.0830890918269057,-0.265507180797345,0.8134989902783991,-0.5083513557642414,0.9160448099742079,-0.2393302887972379,0.12445395769369237,-0.13668106563171853,-0.7886161193340099,0.18319536095878391,-0.9319315757831196,0.8289133436196935,-0.16560414124265752,-0.21421819476742096,0.7634230605932549,-0.17251690228979205,0.9461981944430928,0.34337221151748365,0.20646743929180175,0.4394432788193297 +63,-0.8715734612321387,-0.8884657669889602,-0.8589317377574907,0.2723682098589931,0.016813837552559784,-0.3369410542934259,0.8756639042427814,-0.42500634209721366,0.8545966419248491,-0.2695508352166868,-0.025219270751534132,-0.17602685563477913,-0.8796955782699294,0.20739027680012664,-0.8502027414763919,0.6798903833052802,0.03362304027627407,-0.26754277458807085,0.884886832556487,-0.19777028461354473,0.8457413419893308,0.35231780791195405,-0.042024344038682965,0.33807427782072297 +64,-0.9022971116760534,-0.9172086712685874,-0.8133267822106491,0.25032787754789093,0.11654876862511672,-0.34945206579421517,0.9259235723408239,-0.3885734963687396,0.7779781392470327,-0.4672202141037717,-0.17432612874520476,-0.07116184641755208,-0.9463955613594388,0.2258725917529856,-0.7399775987224757,0.7729473647357156,0.23150977728303157,-0.29361042732436055,0.9649723377429144,-0.19015285267461804,0.6994471052921291,0.3417889450074968,-0.287903257198445,0.452636910748443 +65,-0.9290120377774551,-0.8462840309848875,-0.76207699183097,0.308732779697949,0.2151191829274714,-0.3867928470852793,0.9635946802290106,-0.4037150173161997,0.6875493919145926,-0.39617490578450215,-0.31951799771209266,-0.13658811391163736,-0.9868675761616822,0.2291101809632378,-0.604950571819621,0.7449234440554018,0.4201669500408618,-0.2838498939952181,0.9999346911834778,0.013841376496749472,0.5152443212276875,0.36835334100302025,-0.515881740262373,0.1350075527187752 +66,-0.9515995505001217,-0.8651875291325607,-0.7055380620267925,0.31520135640988595,0.3115401974629625,-0.4365306973335728,0.9881650635948778,-0.47812892700730086,0.5849156418507524,-0.3696189686974637,-0.45753418129029866,-0.05197396824476862,-0.9999900000020655,0.2294308591472208,-0.4496473716596205,0.8320115228485158,0.5920733924140472,-0.30168937090402514,0.988139015360982,0.17694679792553278,0.3031163894155759,0.3750061086179097,-0.7117851953493901,0.019805396154623078 +67,-0.9699592980742013,-0.7905302867925699,-0.6441023970904011,0.272737524050207,0.4048484053269021,-0.387641166255108,0.999300671378899,-0.41345830409903694,0.47189878754687226,-0.49107510232691576,-0.5852751335418936,-0.1554769899161486,-0.9853991641036104,0.18400428823287104,-0.27927330713963194,0.9034502025459354,0.7403757370273192,-0.25136271823692763,0.9301368888536573,0.17049841977321042,0.07456019316573098,0.391071387768644,-0.8634332944651512,0.01774510522033419 +68,-0.9840097118389171,-0.7424247203880844,-0.5781963867448399,0.2846552841198958,0.49411150174990476,-0.41465459823078377,0.9968501074256576,-0.34163545768527503,0.3505050427116589,-0.4291470767505771,-0.6998720680724586,-0.3130809124442334,-0.9434994321350132,0.2297506389497506,-0.09953881843717356,0.8922517889161251,0.859161637396055,-0.24732615339055192,0.8286405539233858,0.2067636894133297,-0.15803700717570207,0.3935374949638586,-0.9613972933022695,-0.1495169370454631 +69,-0.9936883686356974,-0.7019699690877867,-0.5082774468173386,0.3729945962360282,0.5784375993789701,-0.3448903312591691,0.9808466888174643,-0.31285151635137604,0.22288932305693154,-0.5383785357928811,-0.7987513847887721,-0.35646260268353464,-0.8754519938632033,0.19396347146167583,0.0835319202375313,0.8239925806407276,0.9436954745231406,-0.277690372190899,0.6883960891297026,0.37251767526876245,-0.38206893742455966,0.3919272085482479,-0.999586264893971,0.044861376551148316 +70,-0.998952268141697,-0.8062979117986191,-0.4348308445783347,0.40834873588936654,0.6569841397210892,-0.3415318887526639,0.9515079929059833,-0.3194786169314604,0.09131699340632397,-0.48495199343672557,-0.8796924674085422,-0.33746635269403114,-0.7831426844779176,0.06629353420159725,0.2638029133360913,0.757779297371044,0.9906071510842663,-0.33757335207111727,0.5159614775555277,0.4446278617880872,-0.5853935428416657,0.3574890499293087,-0.9756258039519613,0.019972445485279395 +71,-0.9997780239115723,-0.778197921004047,-0.3583663307796047,0.4559064603262504,0.7289663117091982,-0.3793337552717201,0.9092328992002714,-0.1880336080351496,-0.04187634582854418,-0.6112507620734513,-0.9408775537147501,-0.3179668752937279,-0.6691297214286311,0.06852414569757703,0.4352320045905824,0.72523702201256,0.998026446573297,-0.28007330649650575,0.3193999484147674,0.510950121467294,-0.7569910617080028,0.37621384765734733,-0.8910056558870083,-0.14238963619029532 +72,-0.9961619672787466,-0.9280684245355669,-0.2794146017665894,0.4234250495276733,0.7936648932749352,-0.37614049911379,0.854596166328783,-0.1932051849796462,-0.1743263189558588,-0.530060574042246,-0.9809325585761778,-0.35373004451927864,-0.5365728071709441,0.09038972717837572,0.5920733924140472,0.6977683938457208,0.9656575770920255,-0.25267224577515895,0.10790293172039359,0.5916541590372538,-0.8875612730020958,0.3419752137019912,-0.750987091654764,-0.16262461833939862 +73,-0.9881201636545577,-0.8779767350494935,-0.19852361621935746,0.4720849196708154,0.8504334375778019,-0.4509081426629696,0.7883406178052164,-0.09797750493963203,-0.30368174611450327,-0.6162949652890368,-0.9989579329371041,-0.4977091444121057,-0.38914556263605904,0.06871000288420166,0.7290702123591435,0.7339339716780069,0.8947909873207962,-0.27799084295669496,-0.10863974293888247,0.49847557390822966,-0.9700275486242032,0.34331186651522505,-0.5642757874102943,-0.28292226175235763 +74,-0.9756883411528704,-0.9430808572785799,-0.11625479208560974,0.5190028559175055,0.8987047320883108,-0.4276770455650107,0.7113670428380355,-0.12513417626404552,-0.42764638013975387,-0.5452099206464335,-0.9945488657497844,-0.49925061928234266,-0.2309337182016461,0.06337581469053871,0.841630731834346,0.7198858553708735,0.7882519045613308,-0.23993438371903913,-0.3201023055425866,0.4059193139990852,-0.9999203916074106,0.33348647306255896,-0.3424805477300144,-0.4536670191773511 +75,-0.9589217318572629,-0.947150808252016,-0.03317911010062455,0.5257381824854882,0.9379964659880807,-0.3733158874090307,0.6247219494890134,-0.06285620426985472,-0.544019668293626,-0.46069082187063426,-0.9678043751578611,-0.5834440554866966,-0.06632188365236895,-0.024341078302173667,0.9259822515463416,0.7686964196077406,0.6502877058397336,-0.2399726021719177,-0.516596537612602,0.5495674803109192,-0.9756196733592968,0.3280989487164913,-0.09939152636945847,-0.39967895046742047 +76,-0.9378948264360029,-0.9615484538168197,0.050126849062630215,0.5780227339219672,0.9679160492601058,-0.37047700950867873,0.529583336684947,-0.07418153782307127,-0.6507358152798379,-0.5493408059339998,-0.9193250847623279,-0.4434953792855477,0.10012796486507118,-0.11862277238300427,0.9792975553425121,0.7674734166511675,0.486398588387829,-0.18831443220412133,-0.6889341560198833,0.35237052813961556,-0.8984424411902245,0.36771649349579816,0.14987717870571562,-0.5151568473794634 +77,-0.9127010431950106,-0.8781926856622723,0.1330849065093335,0.5214848949424539,0.9881645353184317,-0.3824804670525291,0.4272446785216949,-0.18986589953287736,-0.7459004541179705,-0.5303012517473864,-0.8501997349103194,-0.34214432154232277,0.2638029133360913,-0.09154432069929787,0.9997896702247545,0.6976740628030432,0.30311829413640534,-0.20149498884189565,-0.8290564670560703,0.5416791807883498,-0.7725715371467465,0.38671301287958887,0.38982724672728647,-0.4619710991241565 +78,-0.883452313039142,-0.8431511050236952,0.2151192979309889,0.5325098391558969,0.9985396079836227,-0.42720581310209205,0.31909733860481265,-0.22589730797969368,-0.8278242739154351,-0.6818218268856214,-0.761980731934101,-0.33401798741842376,0.4201669500408618,-0.13243847890535898,0.986771760456212,0.6959372399978908,0.10775363004285596,-0.2836021594164605,-0.9304111996802535,0.7351743158521509,-0.6048288968577106,0.3777973083944079,0.6055397446449408,-0.5649667779054948 +79,-0.8502785821857195,-0.8794082156242914,0.2956606696553515,0.5759029252768927,0.9989376029591515,-0.3974901586357941,0.20661165351579477,-0.03544159996752629,-0.8950530075953472,-0.5036579579440729,-0.6566492844516307,-0.3078047161324786,0.5648866744823758,-0.20473202957686668,0.9406801482856912,0.6172316303541268,-0.09190683124426098,-0.3131735607923452,-0.9882588968101124,0.7887774856055605,-0.4043058151140909,0.37739034679628153,0.7836027141246634,-0.7222382071812826 +80,-0.8133272348396607,-0.9466192517053802,0.3741500302037959,0.6612191417014436,0.9893545436107821,-0.33077080534661024,0.09131694258702432,-0.020514579751780816,-0.9463932472544476,-0.4517852488346605,-0.5365709096922384,-0.36775879850469073,0.6939513912409081,-0.12121460550769086,0.8630596897024333,0.5939620878189663,-0.287903257198445,-0.3202147531714932,-0.9998945374687279,0.9248994997675027,-0.18187021700270053,0.39381668176586654,0.9129450621581827,-0.6628229442670751 +81,-0.7727624383951397,-0.875678058455289,0.45004262993064087,0.6639079653456622,0.9698861806997854,-0.3221248359771876,-0.025219284233850333,0.11540229742839081,-0.9809336288909083,-0.37402552261520006,-0.4044423090794041,-0.2689933054629357,0.8037842605294165,-0.0765331394426572,0.7565119953831572,0.535603105098267,-0.4724218888193926,-0.38075274327629394,-0.9647740264954793,0.9999784543319578,0.05042236054826496,0.4112977761867202,0.9855249080042565,-0.5094383302468176 +82,-0.728764414072919,-0.8275759991429337,0.5228117418180848,0.6869923458663679,0.940727035672983,-0.4374358957254711,-0.14141263774691729,0.03413965001647246,-0.9980610104422466,-0.48142853475520947,-0.2632308061305301,-0.24362985607901788,0.8913414250748817,-0.06600746282446375,0.6246082323137356,0.5024056217947194,-0.6381065505465904,-0.31923092809904063,-0.8845396370251033,0.8434089512810761,0.2799821532628256,0.4070156862890952,0.9968295883829464,-0.5403481224155495 +83,-0.6815286362337613,-0.7861775773179335,0.591952317186408,0.6710819677626526,0.9021684570687498,-0.3772177842281423,-0.25568339034789045,0.04692166156020406,-0.9974713559496301,-0.48598219300931245,-0.116107706767913,-0.2318046707151435,0.9541963661690775,-0.11564387066221349,0.4717694287200295,0.5020528382894593,-0.7783519177651622,-0.30766652689408064,-0.7629432160170414,0.8995716313873495,0.494367507998244,0.44256761541005585,0.9461562330216352,-0.5087795714547242 +84,-0.6312649639252141,-0.8107015905342001,0.6569844909472605,0.6159079203184137,0.8545957094587294,-0.430462980918681,-0.3664779534984565,0.12559124889303322,-0.9791751326387994,-0.42117755368988674,0.03362292137534937,-0.11737796411506829,0.9906071510842663,-0.030959243145450097,0.30311829413640534,0.43754885410695904,-0.8875668502539472,-0.26929354229952496,-0.6056707438277239,0.8874335313480314,0.6819591939172774,0.44080916992685065,0.836655465724281,-0.6368561265866148 +85,-0.5781967085201272,-0.8919667740873336,0.7174569120720905,0.5255859337220001,0.7984841240115808,-0.4067153770764977,-0.4722899998182455,0.20882621818384256,-0.9434971251115717,-0.36892699014668373,0.18259845119123896,-0.21696472406809691,0.9995647081900849,0.044042576619157114,0.12430752116970865,0.46587328974256853,-0.9613972933022695,-0.22220524599131877,-0.42007645161018037,0.8020274006821616,0.8325901401184119,0.47125164235719597,0.6751355138430856,-0.7993619464243936 +86,-0.522559641589201,-0.8229889188356848,0.7729498761608247,0.5700994531527581,0.7343943491410676,-0.462469314396966,-0.5716809426292507,0.2486268808342447,-0.8910706699462959,-0.2873694764255082,0.3274732134522769,-0.22515345819534727,0.9808207918973785,0.1491175050987566,-0.058669676218775134,0.4830530485080047,-0.9968998601312286,-0.28612055729888347,-0.21483892949943206,0.8319113610335723,0.9380964692822186,0.4526235157853081,0.47163890567684835,-0.6835371370159699 +87,-0.4646009474153787,-0.7894232831721396,0.8230782383694502,0.5672688287631938,0.6629667486924334,-0.46190761877517217,-0.6632994945193492,0.2655283113857519,-0.8228264130514906,-0.3408925469024966,0.4649936333295361,-0.21888157927109173,0.9348948624214914,0.12049007387637291,-0.2396804377108924,0.429506598652466,-0.9926591754361821,-0.27713928784391306,0.00044469366266863564,0.7895762950239942,0.9927599624806462,0.4634215424988555,0.2388180747015414,-0.5609635561002569 +88,-0.4045781248038636,-0.7406680408169154,0.8674940864795357,0.5814089080465407,0.5849150036385338,-0.42097209226501914,-0.7459000390129907,0.32094199078119523,-0.7399757893460115,-0.4076867242690937,0.592071298669006,-0.34209972978057657,0.8630596897024333,0.1595461054722756,-0.41265781216651864,0.3631397371547436,-0.9488443019336659,-0.331661605892851,0.21570752246741887,0.8211112334544407,0.9936179744965047,0.459257807974731,-0.008851307462160092,-0.3710523911710543 +89,-0.34275784306678664,-0.7342108243119485,0.9058891555574474,0.6334505664108888,0.5010189812154733,-0.4470089680604411,-0.818359565573682,0.21438396075449748,-0.6439895180254736,-0.7134897791537034,0.7058523190703563,-0.3603408225143748,0.7673060804482136,0.14162731624831118,-0.5718041040756225,0.3864498562828788,-0.867202000364398,-0.3477099956124639,0.42088363387280303,0.7311778467875832,0.9406240029010927,0.4817804827371011,-0.25597035819847547,-0.3139185935873256 +90,-0.2794147572651636,-0.6558904579331776,0.9379969674444899,0.5328528376141131,0.41211694274696403,-0.33139339334838114,-0.8796929376954162,0.14174968205872077,-0.5365714951552691,-0.7764642960286903,0.8037814181126971,-0.33877436880290074,0.6502877058397336,0.1794806568109022,-0.7117851953493901,0.32303323334507983,-0.750987091654764,-0.3274887059502481,0.606378770668057,0.7404066651013402,0.8366502083870517,0.47846870143564546,-0.48717441183429194,-0.5404137236230665 +91,-0.2148302879717708,-0.6399666399105483,0.9635946802290106,0.4960753350103633,0.3190971680142765,-0.32555246402093396,-0.9290662865029509,0.12581955719791166,-0.4196285470346411,-0.6783660054961869,0.8836593193541478,-0.4151859315277886,0.5152475589166583,0.19797059288768226,-0.8279093293524753,0.295602129012532,-0.6048326974776649,-0.36359038068922617,0.763518979333118,0.8190597400629135,0.6873317497597663,0.4946449355570361,-0.6880883201331702,-0.5255359112716173 +92,-0.14929137097622078,-0.7799498270681086,0.9825046348643299,0.5036351783871034,0.22288907985829023,-0.22107480496546525,-0.9658083477658534,0.13637375874223193,-0.2952365812536673,-0.5312669428331785,0.9436921373393631,-0.3976089695703728,0.3659280797702288,0.16702785001741455,-0.916284364858694,0.2669391049597387,-0.43456553231208034,-0.4196193616200738,0.8849562135006114,0.7680236500167861,0.5007613703001077,0.4746560525700237,-0.846220229387787,-0.2808499407557628 +93,-0.08308918248704147,-0.7613553603103683,0.9945955881983128,0.4848653043361044,0.12445395769369237,-0.1641379805770885,-0.989419588191898,0.2918964311174606,-0.1656037363111622,-0.3881925204227681,0.9825316643994714,-0.371683679907887,0.20646743929180175,0.1702820149039759,-0.9739482292376488,0.3385864327343184,-0.24697361072404783,-0.46502723677901225,0.9650119365631997,0.7701864651636376,0.2870507882689059,0.47361161187332473,-0.9517382633801457,-0.09320757521665134 +94,-0.016517845494420102,-0.7103682682417162,0.9997836238568343,0.43450804979377927,0.02477533272282974,-0.23318807844486225,-0.9995789969216438,0.32393187414319224,-0.03303118394141441,-0.46235654813888505,0.9993056484897552,-0.43685309978089204,0.04128485582125012,0.21014115703358302,-0.9989681984631305,0.36734121788926627,-0.049535630646748055,-0.47108081024093806,0.9999426561961957,0.8107737492950738,0.057782663256164456,0.4612057864051722,-0.9980818218248938,-0.2042300013568377 +95,0.05012687695900181,-0.6297206702405063,0.998032734659214,0.429728085759554,-0.07515083918296325,-0.23914062441552567,-0.9961484498892272,0.23135799825122866,0.10012772003443934,-0.6009551739092783,0.9936373820901109,-0.3655226032209624,-0.12504187702344943,0.12559085311178117,-0.9905056763678888,0.4179766524517116,0.14987717870571562,-0.5014010759651129,0.9881149741135105,0.9982583044955786,-0.17461715936527153,0.48894107935679826,-0.9823694867193146,-0.2558447481716523 +96,0.11654889579376129,-0.6314026488541591,0.9893550725234111,0.4925770360212569,-0.17432612874520476,-0.26703367167051356,-0.9791745877129031,0.20669740843472967,0.23150921120056775,-0.577513611899127,0.9656541622438166,-0.30350736069204554,-0.287903257198445,0.19217011893962613,-0.9488443019336659,0.42390191064619953,0.34331485790796995,-0.5182197396634947,0.9300819655407683,0.9677874842080605,-0.3975531028643983,0.49568907231733483,-0.9055781749588148,-0.43771952908559814 +97,0.18245311135771747,-0.6063894714071458,0.9738108641265286,0.5010131014721029,-0.271759609253581,-0.25223212913927967,-0.9488881815841421,0.1999468215579197,0.3587810768062985,-0.5427278489264217,0.9159844317403779,-0.4506542315368754,-0.44278581983888515,0.16699289815192386,-0.8753804425423758,0.42466899173299805,0.5230656571181169,-0.5613610581695603,0.8285573168186818,0.8650850196530041,-0.5989425126558304,0.5084380861008372,-0.7724823983759875,-0.22685582942110713 +98,0.24754672449437742,-0.4018892111146257,0.9515079929059833,0.4919528395938047,-0.3664777575780805,-0.23478092924399757,-0.9057009957764558,0.3201551013647192,0.4796840559379955,-0.6296642395371287,0.8457436656453174,-0.34914976938780695,-0.5853972213338666,0.1771608712959575,-0.7725763918268311,0.3411724013681728,0.6819634792080992,-0.5571775305557305,0.6882884304903586,0.8586155870769133,-0.7678705105127815,0.5360209029070137,-0.5913574077198357,0.0009586064712107799 +99,0.3115405373905749,-0.4417767598641747,0.922601250302425,0.42332809082026224,-0.45753418129029866,-0.34142732078644444,-0.8502001894302686,0.11695991801382473,0.5920719446895983,-0.6625706415000948,0.7565093201350219,-0.2816807122396853,-0.7117851953493901,0.13964120462995278,-0.643877840792121,0.3620968971678646,0.8136735694422572,-0.6149640392051635,0.5158344316161346,0.8603158607651058,-0.895181557781991,0.5626341497985637,-0.3734646776447228,0.09401528058680114 +100,0.3741502384241103,-0.5508682929275641,0.8872912614411065,0.39452838045875227,-0.5440190747037565,-0.35240473320554294,-0.7831403337253802,0.1493184417618101,0.6939496944066839,-0.5821538025409698,0.6502854062318185,-0.30013733960158423,-0.818447084107124,0.13438882311989533,-0.49359838833422687,0.3563992549116834,0.9129450621581827,-0.6824948349357929,0.3192594559810616,0.8939793545311113,-0.9739756661308704,0.5333031892976254,-0.13235172276043156,0.18587987972401324 +101,0.4350976653030245,-0.5511055264367002,0.8458230927079283,0.43684924631627825,-0.625068309351572,-0.42689864375331105,-0.7054331529729766,0.31536987620727547,0.7835088267047319,-0.6063018626350686,0.5294574840368029,-0.1435975227232545,-0.9024269094199129,0.05424853684096489,-0.32677496202536344,0.24072570156591916,0.9758203162106026,-0.6485945672032111,0.10775556237075606,0.7104291140121207,-0.9999823622584344,0.540290966348219,0.11699022120995153,0.10086022028810289 +102,0.4941120408846437,-0.5916740798357942,0.7984845508841715,0.44605186233877325,-0.6998720680724586,-0.343472628976518,-0.6181351291040311,0.2628361516210052,0.8591595365934789,-0.5982095560664391,0.39673908819293036,-0.17948101503396496,-0.9613972933022695,0.0962640268014636,-0.14899899503835154,0.21412330540525726,0.999792693634785,-0.6416656637688058,-0.10878709805403952,0.6908465751345122,-0.9717921384770268,0.5943899966626933,0.359058279858428,0.13336413167891345 +103,0.5509311761838867,-0.56878348962909,0.7456041856446624,0.457645050484816,-0.7676829364361141,-0.37022465786111486,-0.5224331380794807,0.29270666097479603,0.9195589138301378,-0.4419706098881167,0.2551107877469714,-0.2174243634154547,-0.9937239578459853,0.07211879436065304,0.03377098233441966,0.13859128904758664,0.9839064913920381,-0.7361223946826716,-0.320242755938462,0.6420422577677065,-0.890932845034972,0.6065453339129789,0.5788018337358242,-0.10560112703561415 +104,0.6053026352261208,-0.4684938070217103,0.6875490092828691,0.3643839131099978,-0.8278233706611389,-0.4304929491341024,-0.4196283135049542,0.2067287482856787,0.9636347815313934,-0.3884245879640937,0.10775324899443954,-0.2958070479950553,-0.9985110168412102,0.07699486759729186,0.21540905518345918,0.14689106973793076,0.9287950422339709,-0.7242343508832662,-0.5167235156765589,0.6287806991468289,-0.7617868838785207,0.5739168508194704,0.7625582929726507,-0.1687271280653043 +105,0.6569848565693539,-0.44765407364540066,0.6247219494890134,0.35565751737284235,-0.8796924674085422,-0.4142000746978778,-0.31111835683516864,0.31933201753291346,0.9906047288740973,-0.4076693919652235,-0.042024195428285886,-0.2654338825718293,-0.9756258039519613,0.04197000899966841,0.38982724672728647,0.03825681932024829,0.836655465724281,-0.7788903327667097,-0.6890417241201084,0.4695570334106547,-0.591353691775289,0.5959981919917001,0.8989025709435087,-0.3928923624924224 +106,0.7057482265130745,-0.5235670222497188,0.5575590528600193,0.32127343219609317,-0.9227719678092682,-0.47237773971257857,-0.1983785346728031,0.39601254055364643,0.9999900000265173,-0.42594585875572205,-0.19085786702048183,-0.311645104876202,-0.9257025493695171,0.11794595179177927,0.5511795694824891,0.07114888822418904,0.7111610760151618,-0.8100239663772392,-0.8291395951916803,0.5283077060012916,-0.3888703857094099,0.6118566883284332,0.9793574408169485,-0.29998962087949815 +107,0.7513760992256747,-0.46288044873963086,0.48652645855011,0.3361829866418602,-0.9566314357355024,-0.4458669508354063,-0.08294162151646221,0.470830729655385,0.9916239927613844,-0.3801085646822633,-0.33540528238454537,-0.31612013787970916,-0.8501248030506279,0.18646331279259695,0.6940579656604998,-0.03070954484195215,0.5573149384038688,-0.8275815377685178,-0.9304660006895888,0.5192438940810173,-0.16531113076213094,0.6039279137180887,0.9989206084610697,-0.43075790611573167 +108,0.7936657592582477,-0.3497267381961362,0.41211716306621865,0.3294085569920019,-0.9809325585761778,-0.49188818594234857,0.03362293935028854,0.44703476878686177,0.9656552158879896,-0.19951901584367474,-0.4724202181971357,-0.237008821985927,-0.750987091654764,0.15882992649876923,0.8136735694422572,0.04838946671029884,0.3812504129073991,-0.7668361014167681,-0.9882828081386148,0.6170504753092149,0.06720763632468776,0.629148878543633,0.956375730864423,-0.40151956410917333 +109,0.8324293221685019,-0.4176870321875085,0.33484759866192704,0.3808603042520846,-0.995432527544753,-0.4172126908695944,0.14973037350692284,0.4277563361846086,0.9225446530046,-0.2402696881828703,-0.598825614386582,-0.1679183061485999,-0.6310368718023147,0.21312991851762503,0.9060172157496263,0.05264772866764461,0.1899866365535643,-0.7674685793062817,-0.9998864409967778,0.7664279948821757,0.2960838922826761,0.6345646922752508,0.854368042453202,-0.37716053753727813 +110,0.867494569253495,-0.326167291288825,0.2552540490773125,0.33456892802165544,-0.999986463744347,-0.4882595300394491,0.26380212148119664,0.40395192413212444,0.863057579368428,-0.5042663212128222,-0.7117826782683283,-0.20765807146700535,-0.49359838833422687,0.2798399732058115,0.9679938157291611,0.09532341885123924,-0.008851307462160092,-0.7792506583327679,-0.9647343008228753,0.7840316153416478,0.508913030209363,0.6489873225140227,0.699239887226607,-0.22903423792373387 +111,0.8987057126826491,-0.3249842174937003,0.17388892750100785,0.28984478732900393,-0.9945488657497844,-0.53592796443106,0.37428730033793556,0.47464227955408445,0.7882499771456274,-0.4996786574479457,-0.8087546377085748,-0.145226151925367,-0.3424805477300144,0.2841130330121799,0.9975260950751376,0.16578964940912505,-0.20733637778128405,-0.7609314412466066,-0.8844701397673241,0.8032090989076603,0.6941601629128784,0.6581316235734843,0.5006363825256179,0.16242913181502072 +112,0.9259240876317014,-0.19853529810911663,0.09131694258702432,0.3394640032358659,-0.9791740642427964,-0.6316968713707293,0.479683788986502,0.47883030704325574,0.6994497901904098,-0.4597321786869101,-0.8875637115574354,-0.20830632236582106,-0.1818713598376071,0.27466070766438366,0.9936242181867788,0.16655150034342664,-0.3975556010060438,-0.7867415771147342,-0.7628471969411419,0.7466613337973158,0.841785289295784,0.6729088121238037,0.2709057323520994,0.0738462206753816 +113,0.9490287683425958,-0.22409332941895777,0.008111179135783107,0.39704927417816027,-0.9540156791578063,-0.824376224021167,0.5785586505378316,0.4286148464643785,0.5982333511082314,-0.5415490265735277,-0.9464400179185432,-0.14451340281357405,-0.016221873307991897,0.21521890495577053,0.9564189645616212,0.05780788727371706,-0.5719255369779418,-0.7890346897796908,-0.6055526928896939,0.6594923515965198,0.9437874410879401,0.6721614703715478,0.024331475851518816,0.2457842375514117 +114,0.967917105372268,-0.3589116244927078,-0.07515087935888216,0.43670050943011024,-0.9193250847623279,-0.7054383555044101,0.6695676140552493,0.3881845096181986,0.4863973990570174,-0.44166938880339307,-0.984061321880816,-0.12918158345394865,0.14987717870571562,0.24197471938499032,0.8871573454493772,0.0070298157084687,-0.7234946066059249,-0.8530428415766252,-0.4199418889997574,0.7138069799142281,0.9946383183313986,0.679564258223806,-0.22375559396992106,0.0054345099369842685 +115,0.9825051816434492,-0.34270547893646086,-0.15789135869151577,0.37466008838738213,-0.8754488980083742,-0.7213973222245108,0.7514733508313973,0.3823303050522717,0.36592718501117627,-0.47529622774995445,-0.9995827300639781,-0.10996931136717417,0.31182260400076867,0.10086026531024332,0.7881608076832424,0.12118968409057732,-0.846220229387787,-0.8907535717716114,-0.21469414751002797,0.7887695702175717,0.9915819114951113,0.6919210735086204,-0.4579306247009187,-0.0713196610213092 +116,0.9927281852713374,-0.4414247627226193,-0.23953600463855151,0.3586761673126902,-0.8228255152504522,-0.8053130780122855,0.8231622967128864,0.36473148280048723,0.23896123066642116,-0.6474580233998021,-0.9926556651024078,-0.19252481418275022,0.46512632206164967,0.05003061668083828,0.6627474255739261,0.15487204704616497,-0.9352097220263128,-0.9305305899781555,0.0005929248683984557,0.7413696417483769,0.9347838713372987,0.6931710402981107,-0.6636337471389501,-0.37231974377718535 +117,0.9985406975097498,-0.5201630552816429,-0.3195181685276059,0.34667207917873744,-0.761980731934101,-0.8642143122946055,0.8836597917617156,0.30017471650872035,0.10775336656612035,-0.6130223739442876,-0.9634356939412748,-0.2087336622065179,0.6055397446449408,0.020451982977268766,0.5151206887664823,0.06221129622600369,-0.9869153542725878,-0.8964712163453418,0.21585227143774122,0.8442394908761973,0.8273225309862243,0.7023402628673401,-0.828075337437701,-0.31797858166537857 +118,0.9999168945374642,-0.5131639240449992,-0.39728274002098546,0.329148782488845,-0.6935224890216052,-0.8983823601889185,0.9321433314194985,0.2026838236645254,-0.025367276672264887,-0.6305100752584313,-0.9125790341383396,-0.05420536512186793,0.7291715191862504,0.00840890793866164,0.35022861356560064,0.09655980001118751,-0.999275785735511,-0.8631796634834882,0.42101813198907045,0.6897303102958403,0.6750220668235242,0.7010566942024997,-0.9410312139722942,-0.18264343985276774 +119,0.9968506621882557,-0.6636846491950765,-0.4722899998182455,0.3170515840831408,-0.6181347986468624,-0.9999871206100143,0.9679537490158951,0.2692212412037075,-0.15803761381669842,-0.7889945625815947,-0.8412278166323203,-0.04327413388166692,0.8325953719430101,-0.007351360424068147,0.17359789990553667,0.12180618127101249,-0.9717982450182521,-0.93796433907982,0.6064967286526745,0.736566892307677,0.4861368404958234,0.681246109712436,-0.9954783274327061,-0.2545056213287627 +120,0.9893556231149057,-0.714838181072651,-0.5440193655383829,0.3539091121922582,-0.5365709096922384,-0.9589381083293019,0.9906041775874564,0.12592254523054852,-0.2879025532239214,-0.8206176605529616,-0.7509844359443075,-0.11012580980850259,0.9129450621581827,0.03124132730519289,-0.008851307462160092,0.1416731812766308,-0.9055781749588148,-0.9850492037799514,0.7636148813429522,0.4601567523829087,0.2709040300470878,0.6673640729371368,-0.9880314200142769,-0.4365860134938309 +121,0.9774650762664994,-0.698012767584844,-0.6119730047836315,0.30721100399838647,-0.449645781574731,-0.9398847907387718,0.9997866691648573,0.12973485918922673,-0.4126568031449828,-0.8556904938474865,-0.6438755638513672,-0.06292096666143432,0.9679938157291611,0.08270067752331135,-0.1910038449000885,0.187672640541427,-0.8032555607809535,-0.9591877331008812,0.8850255750538026,0.5602967061343761,0.040988796614090335,0.6574058068817075,-0.9191535043517743,-0.2285184827070033 +122,0.9612318489479014,-0.651251619670935,-0.6756792903086809,0.32066966835305294,-0.35822794144086306,-0.958431285315941,0.9953763815348183,-0.029579688877438157,-0.5300857940101242,-0.8943358365643495,-0.5223066347076768,-0.12196224070751299,0.9962160369939796,0.07893590485217875,-0.3667544922789231,0.262316300168903,-0.6689096822142588,-0.9881535671943186,0.9650515142383932,0.4822745747599837,-0.1911479425285062,0.639988299581993,-0.7931270756363296,-0.14985742639338848 +123,0.94072806211868,-0.6614330905129898,-0.7346960733149603,0.2564168011647922,-0.2632308061305301,-0.9339371876023681,0.9774332755548358,-0.06227658892579102,-0.6381049902629059,-0.8753603900971791,-0.3890078245747481,-0.13012288553995344,0.9968295883829464,0.08344989868650204,-0.5302126020940419,0.3444949073278242,-0.5078964854842457,-0.999992634555207,0.9999505992984282,0.6156067164774683,-0.4129248692154293,0.6503106033810196,-0.6177878468063815,-0.13169660790691542 +124,0.9160448099742079,-0.5619201104057032,-0.788613752151911,0.2180812408476596,-0.1656035556178786,-0.9542504760410663,0.9462012999445754,0.0346195706339305,-0.7347968939012778,-0.9167059766476839,-0.24697273735297695,-0.03863436558125311,0.9698174662170184,0.08568815029482849,-0.6758995367355524,0.3627849930839091,-0.32663505863802106,-0.9619166821863349,0.9880909112147178,0.38007368985253936,-0.6123221451105643,0.6633112489088094,-0.40403756186882234,0.024520612348077987 +125,0.8872917552325018,-0.5607851055238864,-0.8370581151271822,0.1919004908182071,-0.06632164911875064,-0.8592343049102897,0.9021050746372444,0.0023770316783192076,-0.8184450828588514,-0.9999807825346458,-0.09939117489189143,-0.15420589137512988,0.9159282719398917,0.1198524099147848,-0.7989322963920675,0.359180292874764,-0.13235172276043156,-0.9129214914021617,0.9300270218481637,0.3151577964906548,-0.7785328614126905,0.6396980245787202,-0.16516617825838867,0.138186819032617 +126,0.8545966419248491,-0.5943810122258328,-0.8796929376954162,0.19773205705684918,0.03362292137534937,-0.789288755163469,0.8457441177830354,0.05571690284866341,-0.887564679995588,-0.9495871843014311,0.05042249908260653,-0.19739744795766628,0.836655465724281,0.08515082287917616,-0.8951871829181393,0.2729690804317011,0.06720805864360312,-0.8872524267510923,0.8284740615588424,0.21913295345476982,-0.9025487503568688,0.6530254335035731,0.08397443834676752,0.19351256820424437 +127,0.8181047278548061,-0.6327198160164098,-0.9162223160000214,0.23373485147381684,0.13323154275332702,-0.7353719184566497,0.7778846948924476,-0.07560725127683154,-0.9409287118289015,-0.8090023298639981,0.19910379243217086,-0.03369586341295423,0.7341959774083611,0.03593834537543288,-0.9614380141313472,0.26175324306509345,0.2640884668368108,-0.8831895425480949,0.6881807567694185,0.1318156396180887,-0.9776484141922787,0.6272454083238493,0.32789393109910414,0.511398372032127 +128,0.7779781392470327,-0.5639280690675578,-0.9463927205721706,0.24872251924238278,0.2315089585966079,-0.6937707929203941,0.6994494009359331,-0.09582919725690656,-0.9775898890993684,-0.8167739051820099,0.3433136438460132,-0.028647857300324642,0.6113893217996845,0.0466437487468204,-0.9954642560001775,0.25836606979188537,0.4504405012365868,-0.8371960722463517,0.5157073743739092,0.11716780724804954,-0.9997616106629585,0.6616753554489513,0.5514265673441451,0.703703917253922 +129,0.7343951504531283,-0.6138531641622795,-0.9699947559334741,0.2260515826529221,0.3274732134522769,-0.7012390799640944,0.6115046170382894,-0.014985212691185162,-0.9968974225342699,-0.8673838100608994,0.47981341095937835,0.1651526573538476,0.47163890567684835,0.01923176334424164,-0.996125448448122,0.1627302623328458,0.6188348942992564,-0.7932857272939471,0.3191189565518283,0.20271489697942013,-0.967689851492229,0.6870880036478486,0.740674170422903,0.8522261373197957 +130,0.6875493919145926,-0.6888013414240189,-0.9868646138899216,0.3032796498053424,0.4201654642073833,-0.7304592510398912,0.5152460123025732,0.050254526650784295,-0.9985085753046877,-0.6613532190988834,0.6055376032789312,0.11390156987597497,0.3188177073502529,0.010411902965584996,-0.9633994302387822,0.4655184583726534,0.7625582929726507,-0.8249200841873588,0.10760819066000465,0.23809106640906208,-0.8831713579015043,0.6968875582599984,0.8838702409818106,0.5829172589078465 +131,0.637648989901771,-0.7164175641007546,-0.9968852104306327,0.27190481745416206,0.5086595605323664,-0.7179959845625419,0.4119822878760787,-0.018723611028264375,-0.9823947471032186,-0.4903200967170218,0.7176627264904547,0.23980508392067254,0.1571609427348777,-0.03195081336589785,-0.8983830817556795,0.5225982927238109,0.8758808988970734,-0.8043161378310241,-0.10893445078547817,0.1857740151368449,-0.7507868527123509,0.7101456052605639,0.9721115409124114,0.7174853411447022 +132,0.5849156418507524,-0.5959227852152886,-0.9999869983408401,0.2592936312563225,0.592071298669006,-0.6741528894486046,0.30311738426887647,-0.006940709382290968,-0.9488419818410697,-0.6975650448384572,0.8136706920540431,0.28934101760664366,-0.008851307462160092,-0.0192778644186205,-0.8032555607809535,0.5057853338678439,0.95428489738448,-0.8075400651705451,-0.32038319931726406,0.1788071437374569,-0.577711294899767,0.7014093213394843,0.9999116535748116,0.6942051911250943 +133,0.5295836314063154,-0.6275320025457827,-0.9961484498892272,0.3381722132817015,0.669567256102097,-0.6904512003352132,0.19013139382847485,0.010428558882996087,-0.8984458899521787,-0.6647317095015265,0.8914053680440264,0.2973326844923104,-0.1746182566234609,-0.06937709860016532,-0.6812052635020888,0.44316490649995754,0.9946445684333027,-0.7990049731299939,-0.5168504824182005,0.3445513559365065,-0.3733250120565157,0.6972515457889525,0.9655421026113206,0.6920134392932413 +134,0.47189878754687226,-0.5639755249628372,-0.9853962062395694,0.3118474865920208,0.7403731188417227,-0.6583441338146857,0.07456043787851079,0.046831649127665784,-0.8321010754408373,-0.5524027466095175,0.9491210012236917,0.26616461519499907,-0.33554591629092423,-0.086166521989574,-0.5363229588000349,0.5152238285132361,0.9953508993211305,-0.8035378136057207,-0.6891492771222292,0.0718529432416093,-0.14870530656294534,0.6943388736776431,0.8711398202346337,0.7603171970873271 +135,0.4121173924158334,-0.499646542810522,-0.9678048925496896,0.3352328921039035,0.8037814181126971,-0.6460981249653535,-0.0420242178945775,0.028984124039896524,-0.750985255358137,-0.7563903927266687,0.9855214228993429,0.22280156481654015,-0.48717441183429194,-0.13081603552672624,-0.3734646776447228,0.5481581759966825,0.956375730864423,-0.8330460137374684,-0.8292227051593992,0.1368221245649074,0.08397391067204463,0.7202269213209556,0.7225742831329085,0.5267106036672927 +136,0.3505050427116589,-0.5895346223204567,-0.9434966000410308,0.2947360997415459,0.859158599148701,-0.6397708801800386,-0.15803752586635586,-0.01644086470424731,-0.6565383536096343,-0.8274151560668076,0.9997891580747996,0.10420808619481577,-0.6253015813511476,-0.1120752783195442,-0.19808895314429317,0.5097932043213641,0.8792728800362922,-0.8130072418284752,-0.9305207813107916,0.1292107451522122,0.31210192032287537,0.7250147876166235,0.5290825768376454,0.7149346241984571 +137,0.28733546969521884,-0.5917210928669255,-0.9126400386334936,0.3208317307393739,0.9059513514620855,-0.5773951205602964,-0.2719022062740284,-0.0890451691668222,-0.5504369402095979,-0.8413605476122751,0.9916037841774327,0.11072644105516326,-0.7460994324357614,-0.05272926042822714,-0.01607386648539956,0.40272309124968886,0.7671161941871404,-0.8672246267150759,-0.9883066978121199,0.08211584622575385,0.5233146688475238,0.7519465198590994,0.30269507848443405,0.7362979659941171 +138,0.22288932305693154,-0.6546255271043835,-0.8754493660266203,0.2925155956593109,0.9436921373393631,-0.5211138995020674,-0.3820701914067154,-0.12306943969224497,-0.4345644697222834,-0.6918936984034708,0.9611491270583582,0.13439201248376825,-0.846220229387787,-0.16058749269719585,0.1664799691506045,0.38582592515637165,0.6243770064508815,-0.8834926068606455,-0.9998783226155739,0.14461692079547803,0.7061648739262473,0.7789537220568848,0.05748746623084797,0.5708007879167993 +139,0.1574529240467105,-0.8025787443433129,-0.832182701353802,0.26679907736629627,0.9720038633234919,-0.5266893231056408,-0.48704367263800635,-0.13666693268851163,-0.3109778472012334,-0.6292296199127334,0.9091091326601242,0.1311608346870219,-0.922889270809022,-0.10096729607339966,0.34345388329287624,0.4487181853429214,0.45674587780301035,-0.8944285775141385,-0.9646945540112608,0.3111606256832997,0.8507424428473884,0.7799539185078793,-0.1912944382369544,0.5589383488530875 +140,0.09131699340632397,-0.8098553055113502,-0.7831403337253802,0.24943339013671623,0.9906036480070516,-0.6669795934555631,-0.5853954641521542,-0.12854093835737648,-0.18187091512987733,-0.5510108058575802,0.836652507065346,0.0707542445058575,-0.9739817863929329,-0.11643592681185111,0.5089162281138078,0.4555847713907267,0.2709057323520994,-0.897511696875171,-0.8844006231292648,0.3000514564370045,0.9492115784817415,0.7607981248296823,-0.4281825810547376,0.49792784638333454 +141,0.024775359755680612,-0.7781129288730496,-0.728662638094113,0.24586600704958803,0.9993056484897552,-0.6110520233696793,-0.6757884065051065,-0.1537820082308193,-0.049535509523345955,-0.5563998809423977,0.745406469877665,0.041430940945441405,-0.9980818218248938,-0.22446014585611926,0.6573211899099233,0.48382631201221027,0.07426543024478444,-0.9142851836447353,-0.762751161149929,0.20481375561147894,0.9962354628424042,0.7483997022025621,-0.6384484048511662,0.5839867169626435 +142,-0.04187634582854418,-0.7007086205805161,-0.6691277129078065,0.22047942031018683,0.9980229172592789,-0.5888743832082391,-0.7569935462158768,-0.14079725724609224,0.08367922349754048,-0.48762009301165044,0.6374202103775293,-0.013955537635429673,-0.9945214798842913,-0.10636387669232894,0.783694668551774,0.4510224995274084,-0.12533560020827272,-0.896926662156866,-0.6054346286829522,0.15710552183188856,0.9892655012834192,0.7515584501580547,-0.8090185991085457,0.6279208561864791 +143,-0.10834200324224282,-0.6886860763756487,-0.6049487559447875,0.2455226981015693,0.9867682709420627,-0.570798758459354,-0.8279068442242921,-0.18193576726630362,0.21540852847011638,-0.7715969133669576,0.5151188671487545,0.027306276184705092,-0.9633994302387822,-0.1324763688783936,0.8838009947674983,0.4259794307155154,-0.3199398958003875,-0.9328733078828125,-0.4198073171876728,0.21323524147062425,0.9286794509159497,0.7515730663156197,-0.9292879352912621,0.33429852867932497 +144,-0.1743263189558588,-0.7678159886809648,-0.536571196545048,0.2245674176737984,0.9656541622438166,-0.6475248363199618,-0.8875641860523059,-0.2638233932552994,0.34331401844225534,-0.8548371718431959,0.38124906469420167,-0.056350355312008604,-0.9055781749588148,0.010978728750045672,0.95428489738448,0.4653988846090967,-0.5017891973756624,-0.9079078127862242,-0.21454936081630455,0.26881663423859686,0.8177609469752724,0.7457998008641991,-0.9917786485905381,0.23648814708079482 +145,-0.23953613794406584,-0.7640904112412368,-0.4644696041406137,0.2880936370403567,0.9348915563592576,-0.6608048908030826,-0.9351544895955247,-0.23312084594307667,0.4651251847453019,-0.8922364781749585,0.23881723017083611,-0.10075981921340288,-0.8226601455347005,-0.06220650117594421,0.9927839622024905,0.44533312084829857,-0.6636337471389501,-0.9010877948610226,0.0007411560611398164,0.40512448952230734,0.6625215367660003,0.7639351134825983,-0.9926053691222754,0.20639258190782514 +146,-0.30368174611450327,-0.663951669384535,-0.38914439454123173,0.35681102834570094,0.8947878230776185,-0.6317902491145484,-0.970030732335616,-0.3204887099366053,0.5786797018286458,-0.7593816019439492,0.09102207551723347,-0.15739796959953833,-0.7169432938324188,-0.060382325996264724,0.998007813205949,0.48343143563947705,-0.7990213136211992,-0.8558020198809594,0.2159970156783608,0.3597866206327004,0.4713748665676806,0.7470187993425715,-0.9317166954079665,0.2767567696888857 +147,-0.3664781574491409,-0.6847252807163223,-0.31111835683516864,0.36522148616231387,0.8457436656453174,-0.6088476152304406,-0.9917187480511526,-0.3008744597896294,0.681961811686446,-0.7331964173060141,-0.058817238720538026,-0.07419253478249246,-0.5913574077198357,-0.09765396372005014,0.9697813621926953,0.5161685881530024,-0.9025544217869346,-0.8443710246844056,0.4211526208801004,0.4037577148844063,0.2546806798784033,0.737861872717856,-0.8128983902696688,0.1646023677117719 +148,-0.42764638013975387,-0.7126963872703047,-0.23093302500992663,0.414998003163392,0.7882491170715743,-0.5235894785947633,-0.9999236734296302,-0.2914236751860201,0.7731381064970291,-0.6732835583391705,-0.2073356445790858,-0.065352244955548,-0.449382916282607,-0.13212512282079192,0.9090506772182689,0.5925227736987173,-0.9701055333312049,-0.8649243786213304,0.6066146733478919,0.32303149919533464,0.024183341322962116,0.7449489927076653,-0.6435380004337727,0.5451433699323006 +149,-0.48691465609234674,-0.7225208949546024,-0.14914491947534733,0.4372602010507678,0.7228786438805687,-0.5333591281965481,-0.994533956934582,-0.27448182808167504,0.8505900742810576,-0.6230481031972923,-0.35119773884951894,-0.12107550467826146,-0.2949544348243706,-0.16143392270627194,0.8178512731626093,0.5855597036031585,-0.9989815986065973,-0.8237200849356715,0.7637107666206521,0.5726735437873901,-0.20762468296282718,0.7499309525157579,-0.43416553465655655,0.3897858907369526 +150,-0.544019668293626,-0.8539859797665557,-0.06632168457455162,0.5568232117508372,0.6502854062318185,-0.49576913415748225,-0.97562287542453,-0.20393358977226453,0.9129428298455946,-0.46718645905584555,-0.48717268904279,-0.09828586045163659,-0.13235172276043156,-0.1803191943023353,0.699239887226607,0.4596340057292032,-0.9880314200142769,-0.8439546849752012,0.8850949172145435,0.604187732843301,-0.4281798904606801,0.7409808689172951,-0.1977987587809895,0.6232691296364428 +151,-0.5987077106754591,-0.8209442934310748,0.016961851109947905,0.5817759127723208,0.5711947317597765,-0.5115733385655228,-0.9434475379042923,-0.20625009304602693,0.9590895209980197,-0.4547038304667201,-0.6122067909221609,-0.03305969520890151,0.03391892365258597,-0.252401260823134,0.5571920260442581,0.3775949670992866,-0.9376915466194324,-0.8735963412715544,0.9650910707676266,0.7347835170895425,-0.6255286574194023,0.7510448615654495,0.050866185892860204,0.568476728421622 +152,-0.6507358152798379,-0.8523072073167836,0.10012766431183594,0.5776381837357942,0.4863968683398661,-0.4607826315025696,-0.8984453899533372,-0.20843753253208747,0.9882109767839498,-0.4815595574948016,-0.7234920481169532,0.026021677072948387,0.1992495567242226,-0.32763084594736985,0.396468718331474,0.26938184397465337,-0.8499688703176599,-0.7891201669484059,0.9999585204900007,0.8447466318769923,-0.7889750998707605,0.7442150580445197,0.2963685174942595,0.6543150813884291 +153,-0.6998728317165822,-0.9099277309619971,0.182598548809052,0.5851708855204515,0.39673908819293036,-0.3685670535726247,-0.8412282663558485,-0.16633689734850354,0.9997902489643353,-0.32276926177326337,-0.8185292336673461,0.10391374525209426,0.359058279858428,-0.37103308585713907,0.2224569391329161,0.37937742064076124,-0.7283606173881955,-0.737986271219398,0.9880668266651327,0.8300816110729989,-0.9096607678207603,0.7931310188364457,0.5234440901154529,0.8634343544390212 +154,-0.7459004541179705,-0.8340145859852982,0.26380212148119664,0.6068029054537983,0.30311722222127896,-0.3835115639889803,-0.7725740727910418,-0.2983440738773591,0.9936217885993456,-0.28553042539082,-0.8951840172738977,0.17485154965341512,0.5089162281138078,-0.3533781642235184,0.040989054179228386,0.4243236755352852,-0.5777149251183075,-0.7757409281874582,0.9299720577770455,0.8115237307478039,-0.9810447544225287,0.7786809119145721,0.7179744444735174,0.7265598664566326 +155,-0.7886141910277368,-0.7864634163951715,0.3431747948925702,0.6004343417739685,0.20646670916242466,-0.36683779652890813,-0.693416209676795,-0.31291312288513606,0.9698150948413159,-0.24039107214336644,-0.9517348977547095,0.20462383820190466,0.6446703077154994,-0.37572788662387147,-0.14185266394844032,0.41898874233034844,-0.40403756186882234,-0.7942193858658511,0.8283907881456931,0.6058714768466846,-0.9992581992334705,0.7682694848138871,0.8678646253269284,0.8931932851677002 +156,-0.8278242739154351,-0.6955142750858299,0.4201656888294076,0.6932699713877991,0.10775324899443954,-0.4070761538188358,-0.6048308819566539,-0.3363216186183452,0.9287927711653704,-0.2564685356156593,-0.9869118642506486,0.12098920193974173,0.7625582929726507,-0.43242264444822864,-0.3199398958003875,0.4233283517412277,-0.2142524960418839,-0.8176377533638589,0.6880730679692388,0.7962120728583675,-0.9633139722331534,0.770228395348414,0.9637951872115138,0.5712162325967661 +157,-0.8633565002797943,-0.719685549407444,0.49424045364355607,0.6549394578901109,0.007963153980990558,-0.3159915881961962,-0.5080224673622445,-0.37976086241207424,0.871283023542456,-0.3596171415358265,-0.9999249179312569,0.133430186582505,0.8593130908588155,-0.4556961839851684,-0.4873036783966224,0.4618658192893034,-0.015925859310604355,-0.8113350091462656,0.5155803058316397,0.7894114605438078,-0.8751601741736009,0.744455563738127,0.9998016324214156,0.5099241033475956 +158,-0.8950530075953472,-0.7039076259052606,0.5648849788669922,0.7140157984417536,-0.09190650623476952,-0.30081917458179547,-0.40430714207953605,-0.4143902465179485,0.7983067337231403,-0.48931687902738313,-0.9904818136653462,0.18696448757461714,0.9322532837150231,-0.4489658231950374,-0.6383344674024947,0.46797445196962717,0.18303569117443722,-0.7597932889241052,0.3189784501301393,0.7993634958635755,-0.7395745536566136,0.7256565417211259,0.9736452545878657,0.469934718988808 +159,-0.9227729746638389,-0.6449653938271711,0.631608961357793,0.6946383125234756,-0.19085786702048183,-0.3217543258334557,-0.29509498643913845,-0.22542564850046892,0.7111593371002062,-0.5271097806414337,-0.9587946232166036,0.10020400177552667,0.9793574408169485,-0.38449901411348997,-0.7679701526429006,0.2725195901804569,0.3747001862548749,-0.7601991311392019,0.10746081659137031,0.7675893857307187,-0.5639055633082671,0.7031665607537941,0.8869523305981969,0.5475179248721241 +160,-0.9463932472544476,-0.7124324144483571,0.6939493082130944,0.7258802304253542,-0.28790223908814844,-0.4077310326452898,-0.18187081391593907,-0.24464077669695092,0.6113878268443625,-0.5630881491894426,-0.90557497256898,-0.03478096543562162,0.9993201393698201,-0.337015646874104,-0.8718657252250657,0.31060922751696723,0.5514265673441451,-0.7276599616534423,-0.10908180112997323,0.7922348880479891,-0.3576740892558896,0.6964619611180283,0.7451130065757331,0.32643488003481413 +161,-0.9658088852532292,-0.5355895065144407,0.7514733508313973,0.7289609314458407,-0.3820699871506753,-0.370891739697894,-0.06617398408085352,-0.26407374223063057,0.5007632925231493,-0.4771048857846092,-0.8320180603412612,0.008929159424877934,0.9915881423911858,-0.3312218971596861,-0.9465389095244047,0.23962774458031552,0.7061693113205663,-0.7358062982239041,-0.32052363567590875,0.7763182474547686,-0.13205743932387398,0.692273387932789,0.5569461647005892,0.20531477604562917 +162,-0.9809336288909083,-0.5533049927129073,0.803781847817241,0.6657058477545351,-0.4724202181971357,-0.3830223333337613,0.050422526038662596,-0.17372533558085118,0.3812494806825246,-0.5183669412596974,-0.7397758162027859,0.05270954465169002,0.956375730864423,-0.3850032825272123,-0.989486878875324,0.1277196582727222,0.8327593133007577,-0.7983072928587566,-0.5169774378347507,0.795489735797058,0.10071644330797519,0.6809163946217288,0.3341511078292741,0.14051742023295816 +163,-0.9917002819776596,-0.6877697856467679,0.850511755886796,0.5979563010481965,-0.5580501825847355,-0.3159282965656539,0.1663335076502882,-0.22220399901750074,0.25496793548660607,-0.6392493992559233,-0.6309198020944293,0.08844014482858606,0.8946587652548067,-0.3816431067564638,-0.9992701429985067,0.2070892797388379,0.9261498293836297,-0.7882255282515086,-0.6892568150238886,0.755160388940742,0.3280317086506715,0.6800762942747486,0.09058015350772429,0.02271867111630931 +164,-0.9980610104422466,-0.6128354336313908,0.8913387495431321,0.5585230495695993,-0.6381042940148105,-0.2550631128153829,0.2799830721876314,-0.18209496568010203,0.12416033821991551,-0.6403395187769441,-0.507894689413418,-0.05086508192112607,0.8081476409645214,-0.371520585884713,-0.9755607955035421,0.19184835645864154,0.9826176744037675,-0.7724463428860785,-0.8293057969574021,0.7532272848777432,0.5375683521772873,0.6814827191810051,-0.1586226360410919,0.33216591411579655 +165,-0.9999875548491668,-0.652992032006204,0.9259794720335266,0.599978372126126,-0.7117826782683283,-0.31390620147751985,0.38982607658625396,-0.24413634738766404,-0.008851285819143564,-0.4877524045662919,-0.3734633569641671,-0.12653213105909591,0.699239887226607,-0.3739069895571498,-0.9191535043517743,0.34598927634562066,0.9999116535748116,-0.7634439674615098,-0.9305755415426644,0.8540711770303312,0.7179699328986444,0.7276349000603903,-0.39796303835712543,0.5272570294486983 +166,-0.9974713559496301,-0.7773172969660384,0.9541935019662018,0.5701017543310525,-0.7783491652846606,-0.2794438176042866,0.4943691305536854,-0.3008659417616971,-0.14170578664804576,-0.3175319605580715,-0.23064484265672702,-0.022841799380026964,0.5709537230901615,-0.3763806900043454,-0.8319388769112056,0.29193722017001117,0.9773423105215168,-0.7620423155621154,-0.9883305658301034,0.9208429578083062,0.8594590683419905,0.7413265882709479,-0.6125600264507652,0.4193072596701785 +167,-0.9905235927085793,-0.8284621454150561,0.9757850219373507,0.5680386876944907,-0.837138644729493,-0.30765714875589034,0.5921909002253344,-0.24585821412288233,-0.27204480345000664,-0.358802348403232,-0.08264654242399118,0.026465974885453533,0.42684441189542627,-0.3705058333276816,-0.7168400923288917,0.24729337479086677,0.9158094137297402,-0.722454615320051,-0.9998701823252941,0.9977915487616215,0.9543673479122683,0.7624556364131477,-0.7890709989859699,0.20162453111489656 +168,-0.9791751326387994,-0.7755102661531552,0.9906041775874564,0.5301066060312014,-0.8875637115574354,-0.2768564242875551,0.681961432164296,-0.25598134702134595,-0.3975546289120927,-0.4791207489473424,0.06720782097621032,0.04912312833112173,0.2709057323520994,-0.30422332510597727,-0.5777149251183075,0.2596448551514561,0.8177660856162815,-0.7396522377243049,-0.9646547860615091,0.9075078016864845,0.9975509444574153,0.7445118771851542,-0.9165213586075023,0.3817606079142788 +169,-0.9634763946627743,-0.7325630759303934,0.9985481176544713,0.5057932286902381,-0.9291205351690502,-0.34477603059914097,0.76246023498214,-0.16123418254746608,-0.5160072808986288,-0.5985615643820933,0.21555284160875476,0.012993528960474785,0.10745929680581393,-0.42221984098086635,-0.4192264438443876,0.1901377298971006,0.6871210042794174,-0.7478729527991358,-0.8843310871124457,0.8630374189117807,0.9866693987736797,0.7339783260816132,-0.9869868592498809,0.39444187563121313 +170,-0.9434971251115717,-0.7505428278764106,0.9995617078054543,0.5368687633552037,-0.9613938935196727,-0.3833795223231589,0.8325928727487738,-0.004035575182070193,-0.6253000523778837,-0.580135965220611,0.3590570101231171,0.08198659241639113,-0.058965215924966195,-0.4155643456534302,-0.24668671870801795,0.16892745712820553,0.5290825768376454,-0.7728900001895255,-0.7626551086455092,0.7641879484452553,0.9223124675229432,0.7645753690398748,-0.9960862973772684,0.29270410324524043 +171,-0.919326087855937,-0.7229599212925478,0.9936379132923613,0.5294235610869673,-0.984061321880816,-0.49064412205382235,0.8914058445926608,-0.006836909810463561,-0.7234928375332997,-0.4053705625771066,0.4944975322710992,0.11153708202399196,-0.22375559396992106,-0.4278005538798227,-0.06587877649448187,0.2306700461745371,0.3499512966739757,-0.7094825568576197,-0.605316551210083,0.758627745987643,0.8079681596743896,0.737243939177728,-0.9432539139993165,0.1684936792713442 +172,-0.8910706699462959,-0.7430733268577483,0.9808178477761973,0.5312131554725424,-0.9968963348011504,-0.5656745209513335,0.9380995481931627,-0.0346841138410854,-0.8088425714663853,-0.41582302043022284,0.6188327059177012,0.09956546676977897,-0.38234491311059304,-0.36377821773531016,0.11713722956492027,0.31192492836146324,0.156868562647098,-0.6753345060443551,-0.4196727361768817,0.5780929638283581,0.6498336938274741,0.7677882394897065,-0.8317745708249755,0.22634676587095387 +173,-0.8588564045137166,-0.6856120844440625,0.9611904879806703,0.48078870234092536,-0.9997706890742684,-0.6355416982811356,0.972039150917484,-0.08274120631960244,-0.87983417213645,-0.43272609815361085,0.7292702311135444,0.14883015637781338,-0.5303381024031278,-0.35534220088770707,0.29622713839892695,0.4424457717845697,-0.04246802594515088,-0.6535086528571514,-0.21440456942142735,0.3836084771374969,0.45647962212773363,0.7752438843936666,-0.668579513471402,0.07837327962012283 +174,-0.8228264130514906,-0.7629376537907792,0.9348920561557704,0.5133916781327151,-0.9926556651024078,-0.6246820360978406,0.9927632208017394,0.006154408870383383,-0.9352074352726826,-0.28127616162910596,0.8233299191319133,0.21233235039282647,-0.6636337471389501,-0.3549477288213638,0.465388380228658,0.40647372551809535,-0.24011154835856524,-0.616943377546797,0.0008893872376411567,0.18561825333116916,0.2383853245457482,0.774449720200236,-0.463815420182426,-0.1447343566283398 +175,-0.7831407695551633,-0.7192112588765311,0.9021050746372444,0.5246746704548766,-0.9756223538531809,-0.698287671266355,0.9999900000320824,0.015647128014011115,-0.973979404834729,-0.3989408397087701,0.8988993921605662,0.2305315302269452,-0.7785377535525064,-0.3439414661752621,0.6189511651448041,0.44700487403085426,-0.4281825810547376,-0.699669696556786,0.21614175518610942,0.32191936250445746,0.007371048742645748,0.7589071798697787,-0.23021353051998736,-0.20665882402579874 +176,-0.7399757893460115,-0.5981128645431862,0.8630570990637212,0.49164497027630166,-0.9488409465421256,-0.6531107305936077,0.993621235633665,-0.030439756783913113,-0.9954618219135256,-0.39045450485288996,0.9542815227533997,0.23766740698331987,-0.8718657252250657,-0.4217027376945137,0.7517685178772217,0.38155314749976577,-0.5991833254525278,-0.709217804046541,0.42128710054293955,0.2537108970630501,-0.2240427222449998,0.7606579822167906,0.01770192144906925,0.002538746811184534 +177,-0.6935232457375787,-0.5724672682224896,0.8180191390435021,0.488754107641418,-0.9125790341383396,-0.6817197098569582,0.9737435152866017,0.09379191663220925,-0.9992733423289929,-0.27117392757879344,0.9882325476540885,0.32356274531035345,-0.9410312139722942,-0.3962267463785142,0.8593887893474554,0.3199851647595635,-0.7462965214968447,-0.7149784843660417,0.6067326047511222,0.13051828615191055,-0.44331385359019876,0.7359483330706535,0.26451675368028593,0.08831671524336411 +178,-0.6439895180254736,-0.6168707878821066,0.7673037772322717,0.51883521790853,-0.867198933683964,-0.6753391178378502,0.9406270901076131,0.054034018575335546,-0.9853463060444431,-0.26625686036552726,0.9999900000374283,0.2858840252794433,-0.9841173994713753,-0.4677782872867103,0.9382048629215747,0.1529921333677516,-0.8636572303039475,-0.6624785104311448,0.7638066351641228,0.15536489186621422,-0.6385583157416571,0.7428753456554406,0.4948852153337387,0.35911367629327323 +179,-0.5915946745857159,-0.5373270030805528,0.7112629998761175,0.5487510518512376,-0.8131540681420757,-0.7428364159479898,0.8947222004024376,0.12548076745709308,-0.9539279382307397,-0.2591215501563565,0.9892898328704534,0.27766787936455106,-0.9999302111317135,-0.4941911085167567,0.9855750544063959,0.18533375240341457,-0.9465866507668218,-0.6448925825968654,0.8851642399813108,0.32878270580096003,-0.799194273581802,0.7525561324155817,0.6944841112353274,0.26943015975143925 +180,-0.5365714951552691,-0.5511802009332026,0.6502857538768221,0.5463793402705781,-0.7509844359443075,-0.7619295263013105,0.8366529543428967,0.12305096580683318,-0.9055759606595722,-0.2016000424610222,0.9563723488395427,0.31355161834645506,-0.9880314200142769,-0.4973950611625626,0.9999116535748116,0.337564749526306,-0.9917786485905381,-0.6434340662364078,0.9651306061500325,0.3671918977946466,-0.9165155994126397,0.7358229867341827,0.850903348779433,0.48291249400590486 +181,-0.47916443663855113,-0.5042960737783164,0.5847952473344464,0.6534195179639094,-0.6813112155050336,-0.823466855520637,0.7672088427921657,0.1017935208034044,-0.8411486913300028,-0.13489346309067973,0.9019768036701457,0.3392919068029137,-0.9487507837054554,-0.5057333897453817,0.9807341395674377,0.28687727650770106,-0.9974315614334571,-0.6319446922457373,0.9999664197707405,0.35834803383338787,-0.9841637265837466,0.7237014084710526,0.9544175373798285,0.39761455531294454 +182,-0.4196285470346411,-0.47682937464338393,0.5152460123025732,0.6232998679658788,-0.6048305586119814,-0.8694721897670449,0.6873340056399617,0.11064989177947933,-0.7617898080743309,-0.3546592310658909,0.8273248040369273,0.3312881293636325,-0.8831769075680354,-0.6147544033188956,0.9286852865459446,0.29920031613614756,-0.9633200254993318,-0.6655412303749393,0.9880427204652819,0.2726709270717901,-0.9984722702980908,0.7276735357735584,0.9985906661521659,0.27169456512795154 +183,-0.3582283323103875,-0.43546833436486443,0.442120750141946,0.6675619690344573,-0.5223066347076768,-0.9010362867753404,0.5981143955791197,0.07418823427038891,-0.6689080466114544,-0.3851066251806441,0.7340928729114915,0.39335188778787794,-0.7931270756363296,-0.5665231686850866,0.8455096197837403,0.3899587486772397,-0.8908039600806766,-0.7459968224017002,0.9299170733286206,0.3628164891158306,-0.9586657378800326,0.7323699040646934,0.9806762638984534,0.0645892669879056 +184,-0.2952365812536673,-0.3647565157581919,0.36592698136711643,0.6925331698714323,-0.434563995560995,-0.876014559036492,0.5007630138407391,0.25030996236315006,-0.5641521934368313,-0.3902826301082068,0.6243747984707934,0.5783539113862357,-0.6810968872480894,-0.49929348943759816,0.7339949442827645,0.3362414400241003,-0.7827743518680612,-0.7519533185137608,0.8283074965810604,0.3566422228172831,-0.8669015588310748,0.7247137783633437,0.90178816138383,0.02344234788813011 +185,-0.2309331535277489,-0.3167041274452501,0.2871935232375599,0.76067316821491,-0.34247933661842295,-0.9449208799622723,0.396603418618098,0.1266036335453885,-0.44938181746167843,-0.43181942196325057,0.5006346121286421,0.4814121475553296,-0.5501910951602332,-0.5224963060217537,0.5978789057061601,0.3265222452362543,-0.6435380004337727,-0.7476624908012768,0.6879653640921747,0.36677303595758576,-0.7281531567336006,0.767380552433972,0.7668312387343396,0.32734614745013363 +186,-0.1656037363111622,-0.34502064826926737,0.20646681954029744,0.7673675574804435,-0.24697273735297695,-0.7988398333863977,0.287051730393562,0.2225333984370652,-0.32663425995737383,-0.2726827750841868,0.36565125170226676,0.5503810327783901,-0.40403756186882234,-0.5607467507332466,0.44172371543087957,0.3104440069286825,-0.47864581972378417,-0.7099420513637784,0.5154532259921075,0.4667811974014035,-0.549940400257087,0.7586756989770775,0.584196463747095,0.29827639302714604 +187,-0.09953857504710924,-0.33992355763827037,0.12430714803691101,0.7486802539297343,-0.14899846813423626,-0.686755266214585,0.17359737881824086,0.2530354198082239,-0.19808846878167158,-0.3125064769312221,0.22245615245998135,0.6346228946150279,-0.24668671870801795,-0.6591311148511917,0.2707632385700804,0.32751134206879495,-0.2946715406356436,-0.7229342603411069,0.3188379367190802,0.5243132837450064,-0.3419220422496351,0.7908452358503261,0.3652391821536459,0.3746404527279033 +188,-0.03303118394141441,-0.3542197554752525,0.04128473189685996,0.6534197450757225,-0.04953545547424012,-0.6042455949198662,0.05778285290369017,0.15501419270311875,-0.06602631888182521,-0.2738786299776591,0.07426516762045728,0.4897306343961041,-0.08249931407193935,-0.6172285092417864,0.09072757012725888,0.361313389651743,-0.09894963711217238,-0.7089237159881514,0.1073134401680858,0.5299490028125907,-0.11537223585037017,0.7807022591139973,0.12357309722111383,0.40951418052665767 +189,0.033622958061977536,-0.23868199588516725,-0.0420242178945775,0.6911612504545314,0.05042249908260653,-0.6479931577286316,-0.05881727016445322,0.09489244636973049,0.06720789430798035,-0.20486276241326587,-0.07559365277762403,0.5523064259566913,0.08397443834676752,-0.678569466058801,-0.0923490210236156,0.33745021820136295,0.10071707618931759,-0.6599669224783776,-0.10922914908430306,0.5208163634973746,0.11743050066256167,0.8211209818961193,-0.12577616438005726,0.4728620755356516 +190,0.10012772003443934,-0.26614671793055633,-0.12504150168634226,0.697438216992314,0.1498766486960855,-0.7431985779043407,-0.17461773247336926,0.10885210410017937,0.1992490695237194,-0.24021206163128891,-0.2237548027045622,0.6271326388697341,0.24812096207186515,-0.7421543902393448,-0.2723303431762871,0.4697779305432528,0.2963685174942595,-0.6662024696192816,-0.32066406501132216,0.40307479031655064,0.3438687534810956,0.8063187949176538,-0.3673052732670287,0.3218676466898526 +191,0.16618763471273568,-0.1832084479918421,-0.20719094411383734,0.7379870707935955,0.24783328037842636,-0.7976789812712686,-0.2880441480416971,0.2149961716908297,0.3277532836623877,-0.17492903829289258,-0.36689090214948894,0.5960997100376236,0.4053911763653313,-0.7351986813756235,-0.4431839489854308,0.4698749729807766,0.48020468125164417,-0.7073421712758955,-0.5171043819234213,0.5041184372447447,0.551670050294203,0.8267261809861431,-0.5859971192764365,0.23025271095903577 +192,0.23150921120056775,-0.27681687197781213,-0.28790239300175924,0.741944511614534,0.3433136438460132,-0.7704370743645386,-0.39755440766687794,0.3201393249850457,0.45043939982967546,-0.21375631307701168,-0.5017874229019954,0.6762873478691019,0.5514265673441451,-0.6610742144513835,-0.5991833254525278,0.36905309057601093,0.6448965997410208,-0.7460742148106392,-0.6893643378227279,0.4320522057223851,0.729572002218414,0.8106034853059807,-0.7682545026401588,0.26435592603900343 +193,0.29580223888950874,-0.2234228062494285,-0.3666156764634084,0.790587479572948,0.4353637308686965,-0.7256780663361395,-0.501659644947771,0.31017183111391133,0.5651295608410086,-0.24243953249762332,-0.6254148799256174,0.6841182947123903,0.6821799778154783,-0.6585428647002151,-0.7350998298874049,0.35534687129454123,0.7838785258876318,-0.798217762208775,-0.8293888705838742,0.45018757556427236,0.8679327013521737,0.7971029523072551,-0.9027455420099297,0.2648273484625302 +194,0.3587810768062985,-0.30304443684416343,-0.44278449073256754,0.8341030893721426,0.5230638074046525,-0.7466265580464106,-0.5989444784351692,0.2921183482067915,0.6697878493574049,-0.365106036043419,-0.7349968670605935,0.718968777385184,0.7940277683168202,-0.585411808406148,-0.8463779387137614,0.4341395260708944,0.8916096888788322,-0.8263257064898584,-0.9306302813840046,0.4739627272725633,0.9592532917030901,0.7859599059372201,-0.9811082359545006,0.23202420900535412 +195,0.42016592265815217,-0.26984329839739535,-0.5158801917446711,0.8764099849158177,0.6055376032789312,-0.7179667062001966,-0.6880862547044857,0.2520157958099333,0.7625564283823844,-0.2560476245345448,-0.8280724091206144,0.703944559024269,0.8838702409818106,-0.4636230465571613,-0.9292879352912621,0.4337263548121031,0.9637951872115138,-0.8291993015340791,-0.9883544121920417,0.44243522099838856,0.9985843912539538,0.774218315125716,-0.998470371707938,0.2186172528175455 +196,0.4796840559379955,-0.31399192251194713,-0.5853954641521542,0.8282337393022485,0.6819610675853659,-0.8412914842644552,-0.767873030727915,0.1816031395870493,0.8417885205735445,-0.38054814486168254,-0.9025512300899773,0.6406500095336508,0.9492175431313883,-0.48425968812193715,-0.981050919105283,0.4514953195498216,0.997557212861661,-0.7880416316341179,-0.9998620201261161,0.34986262454521544,0.9837943379547703,0.7689569612354813,-0.9537524557612311,0.2233534148923774 +197,0.537071049566752,-0.31450305994358246,-0.6508478423093741,0.7704059271319672,0.7515706023308427,-0.8639133246168571,-0.8372200510635122,0.2336390650615425,0.9060776409379818,-0.3125728717228705,-0.9567606962165225,0.8345695960986332,0.9882586698978062,-0.4556482864196094,-0.9999319463818837,0.4575895043729863,0.9915497804088362,-0.7376830082214636,-0.9646149969744904,0.43570564039918036,0.9156847212778731,0.7728112811261312,-0.84973483154024,0.11823502383663392 +198,0.5920719446895983,-0.32498686911377705,-0.7117830587900028,0.729671151288476,0.8136706920540431,-0.9414360915468778,-0.8951844958426122,0.27394742233758973,0.9542825639886631,-0.2913618475218785,-0.9894833797597202,0.7062154041402812,0.9999116535748116,-0.4643795156987278,-0.9852981803271714,0.37481959488389205,0.9460123872273839,-0.7317783960370636,-0.8842615317183934,0.5455529062867026,0.797946937771229,0.7829255815434659,-0.6928848111178505,0.1620993433020221 +199,0.6444423834068088,-0.5139145752238674,-0.7677781971967729,0.7364238712428328,0.8676408531860084,-0.9420020508198257,-0.9409782990466675,0.25318003768341185,0.9855475821583667,-0.09253760764748979,-0.9999843997931115,0.7365921657015352,0.98385354876958,-0.4060655600299774,-0.9376401019651358,0.4478127658429462,0.8627604654818977,-0.6689275396114137,-0.7625590394299897,0.7403526933819582,0.6369621252066716,0.7785734939181045,-0.49295456907319835,0.040121514397034735 +200,0.6939496944066839,-0.4677377440415345,-0.8184446273816793,0.7725123662381725,0.9129418337168758,-0.9166797877790558,-0.9739788628003345,0.2735343542274845,0.9993176958548495,-0.050777382206222084,-0.9880279260456021,0.6892153448888411,0.9405293823617579,-0.37156287940366245,-0.8585550706487349,0.4954345493912623,0.7451130065757331,-0.6614191880371216,-0.6051984604736681,0.5841829412978584,0.44145531846722413,0.7731006714167619,-0.2623747995102293,-0.010828352242706781 +201,0.7403739266773363,-0.5138164106485335,-0.8634307027049929,0.7930642222275178,0.9491210012236917,-0.8965242083130336,-0.9937375220126885,0.27440073829083084,0.9953484655116557,-0.16893902035723407,-0.9538824751409333,0.6837298282990034,0.8711398202346337,-0.4164887270584241,-0.750693785247781,0.44814412062500525,0.5977602434374559,-0.6613888977847695,-0.4195381459703299,0.7782196087761387,0.22202257118986288,0.745695068337471,-0.015481835705404115,0.012688787451261228 +202,0.7835088267047319,-0.478939729308828,-0.902424200612962,0.8374810597556926,0.9758168654239697,-0.9049382049237824,-0.9999856442840467,0.24692475182611517,0.9737103507463053,-0.09909834363385213,-0.8983148802931981,0.6325634444546345,0.7776078925768971,-0.3619703138609784,-0.6176714404780257,0.3104351305261738,0.4265766657348627,-0.6273257450427164,-0.21425977332857601,0.617415204699545,-0.009443327833880081,0.7044405882873123,0.2323737136585305,0.14172827337195282 +203,0.8231627548155148,-0.5759667565596028,-0.9351544895955247,0.9017995413845665,0.9927626900671027,-0.9283358649003866,-0.9926382821459993,0.2245683955230295,0.9347874595994989,-0.17785767460652668,-0.8225730698860527,0.6718376000408841,0.6625256999145963,-0.4070096551423921,-0.4639465561512684,0.36842036411392376,0.23838682250992616,-0.6508805650177836,0.0010376183946509156,0.6441017812544829,-0.24039741854819335,0.6989823407024248,0.4657813909909695,0.026723137044714688 +204,0.8591595365934789,-0.5511589167815552,-0.9613944074844338,0.8868807913037107,0.9997891580747996,-0.8993684057679644,-0.9717953279799708,0.172477062169891,0.8792707300581611,-0.20503233662739,-0.7283580416915835,0.6386023347041929,0.5290825768376454,-0.4628064285623604,-0.2946715406356436,0.32102664729856617,0.04069324894464566,-0.629133245981949,0.21628648995782263,0.7843893414258646,-0.45832247982047253,0.7275702431734352,0.6702290374070955,0.09920865329739657 +205,0.8913392455870175,-0.48274523448140144,-0.9809618380556229,0.920865677253616,0.9968260633013697,-0.8829015393399409,-0.9377401559124097,0.12949358659855809,0.808145664900214,-0.2530342991757413,-0.6177856621274931,0.6063512042991235,0.38097670425803437,-0.4418848644978943,-0.1155199972015296,0.44095626302936397,-0.1586226360410919,-0.5958469764485311,0.4214215709746444,0.7496769408515128,-0.6514074362380903,0.7076918762534439,0.833005088478837,-0.013541135456314274 +206,0.9195589138301378,-0.5085007003220058,-0.9937209749934839,0.8648824045651194,0.9839030120103199,-0.8828021774157268,-0.8909357691509632,0.13841613583459597,0.7226748375854438,-0.48386074230377996,-0.4933391484588829,0.5919471196291681,0.22231261973676972,-0.3890369451182721,0.06750343856070244,0.450121682943994,-0.3516147370905449,-0.5488937183735413,0.6068505228597841,0.559280588805679,-0.8091874933071129,0.7407556130204135,0.9439889177435442,0.026065880612794853 +207,0.9436931670203406,-0.607980407696213,-0.9995832644446337,0.9287675429854978,0.9611491270583582,-0.818725890537002,-0.8320185051412194,0.11354561501161226,0.6243754797383672,-0.5477970436288564,-0.35781330109198134,0.4540394823885981,0.05748746623084797,-0.3707102243597731,0.24826435655790655,0.46874050183934973,-0.5305890681566583,-0.5188616426106809,0.7639024869712551,0.30582604340893454,-0.9231113070849405,0.7246387957918351,0.9962800882389792,0.17107504546657812 +208,0.9636347815313934,-0.6918722367361146,-0.9985080196194357,0.965722278559312,0.928791757742489,-0.9110296902083629,-0.7617893841266835,0.1382771281774754,0.5149925480779199,-0.4671576207923793,-0.2142517383822645,0.45263027106133913,-0.1089308682902638,-0.34896411834123736,0.4207041796729208,0.3512097168817629,-0.6884104874459999,-0.5561487511203201,0.885233543352591,0.16995342920141535,-0.9870044488575602,0.7177134947242643,0.9866273882519041,-0.09332387181634119 +209,0.9792951607863104,-0.6679373159397465,-0.9905027031756746,0.9711209619532603,0.8871542082009943,-0.9987915212082173,-0.6812032187342151,0.08218873500850418,0.3964677488951438,-0.6098827383046554,-0.06587854352782034,0.39256780527582746,-0.2723303431762871,-0.4218125425618471,0.5790432295439465,0.27817786969927843,-0.818787153005781,-0.5305911695615797,0.9651701203847454,0.23222596206077753,-0.9974040460603334,0.7266403691912515,0.9156309759154228,0.035345209387149974 +210,0.9906047288740973,-0.6892761811924516,-0.97562287542453,0.9999918914087125,0.836652507065346,-0.9478305794014092,-0.5913556326474839,-0.016264646561890268,0.27090506993953856,-0.6627213186516584,0.08397414138854217,0.4134140092667525,-0.4281825810547376,-0.43783327394728944,0.7179744444735174,0.210356607063165,-0.9165213586075023,-0.4925940069969338,0.9999742971404736,0.20317987202954177,-0.9537464626153084,0.7629284121582889,0.7877050642830827,-0.18400092112729874 +211,0.9975132396615125,-0.609937804290293,-0.953971808718312,0.9628840937424251,0.7777912506396496,-0.9857765836671346,-0.49346816485842154,0.06145282916917917,0.14053343157324344,-0.7011619387571374,0.2319409481268253,0.3660880018481415,-0.5721683651842594,-0.45388473083479836,0.8328412566097173,0.1534582695996758,-0.9777167498919384,-0.5845551074606459,0.9880185926156925,0.2447508388261329,-0.8583978468093861,0.7393730066734668,0.6108034669410987,-0.24055395053025827 +212,0.9999900000265173,-0.6814467695243744,-0.9256997706962817,0.9118643955079324,0.7111585611411689,-0.9296153754783923,-0.38887166201448686,0.006736540328126095,0.007667120189453631,-0.7386252263007321,0.37469886120519863,0.30096262635251503,-0.7002973402270305,-0.47348215697081003,0.9197936664767834,0.19065291607026902,-0.9999336597218027,-0.6187074070434971,0.9298620685040927,0.27721387470505743,-0.7165258910832557,0.7761264505851783,0.39592506840323277,-0.18616441418764873 +213,0.9980240062216347,-0.5702110685566197,-0.8910029813631342,0.8854374860182198,0.6374202103775293,-0.9607634858373308,-0.27898818465407005,-0.1412927638522742,-0.12533529374050187,-0.5927638683695321,0.509041845663715,0.19306799814192963,-0.8090185991085457,-0.4396642100309936,0.9759172836935449,0.14871587885769585,-0.9822863700116937,-0.6451517789282568,0.8282241868667674,0.1306114092725207,-0.5358197540781281,0.7834824796811155,0.15642996674395915,-0.5811589046061376 +214,0.9916239927613844,-0.5545353813517195,-0.8501222512384518,0.889363335886226,0.557312967574908,-0.9554950147465985,-0.16531167332706573,-0.07447473605004847,-0.25611282372694705,-0.6254850654785763,0.6319528476977322,0.13756750364857515,-0.8953190912117935,-0.42536932395292154,0.9993310088123568,0.11124664718187453,-0.9254784225131113,-0.6664413227526771,0.6878576451405959,-0.039594594780879656,-0.32607332457041566,0.7711454452638586,-0.09279119259122257,-0.5043816263615976 +215,0.9808183936165933,-0.5196716881442881,-0.8033413100289714,0.8686536390589332,0.4716372378234493,-0.8771325999378975,-0.049387637457823146,-0.14387067234175982,-0.38234397820947147,-0.6148806651023316,0.7406715511819586,0.1151525661295014,-0.9568071246666158,-0.43705025693633226,0.9892500822690169,0.16392687311258836,-0.8317745708249755,-0.731551889778235,0.5153261348580912,-0.07326763155297292,-0.09865441350383979,0.7769120395815542,-0.33624304499791974,-0.7647765479483626 +216,0.9656552158879896,-0.49492883815660776,-0.7509848374233878,0.8619358751974233,0.38124906469420167,-0.8133043280178602,0.06720785690576209,-0.19490975111430436,-0.5017879704120822,-0.39266562573595953,0.8327563684197596,0.09081109434954032,-0.9917786485905381,-0.5214985975728861,0.9460123872273839,0.1957376113714279,-0.7049104918144901,-0.7017615383900627,0.3186974163217264,-0.16452736718149183,0.13411135721659903,0.7928695612567384,-0.558788933433369,-0.659344274893933 +217,0.9462018265203238,-0.666304641388301,-0.693416209676795,0.8409027239736516,0.2870515769347184,-0.8961781195914922,0.18288961478125398,-0.19668925966507292,-0.6123244955708953,-0.3086416874767901,0.906139272939061,0.06009150604602519,-0.9992644783657527,-0.5365054359881183,0.8710671247242782,0.2883354833606611,-0.5499438559690574,-0.723555290600769,0.10716606139337687,-0.13671077197989995,0.3596085772502678,0.7904012741313343,-0.7465920324381476,-0.6838632937910131 +218,0.9225446530046,-0.7560329160363098,-0.6310349776244631,0.7876110214008742,0.1899859647051359,-0.8498106821019342,0.29608486405473383,-0.1650579036298051,-0.7119913694307172,-0.5561773599074953,0.9591722429085168,-0.010853908540354524,-0.9790571551727493,-0.5682643113836634,0.7669262406906555,0.34897091994224455,-0.3730526940343257,-0.7526585801630095,-0.10937649464522486,-0.3645168710801117,0.5656157762641705,0.7973783134505826,-0.8879756549256671,-0.6202757860075052 +219,0.8947887993981406,-0.7965888028485396,-0.564274093628605,0.7033263069663895,0.09102207551723347,-0.8238136630479551,0.4052546383865218,-0.1772311226731414,-0.7990193598723806,-0.46332259380440466,0.9906642721549075,-0.013814541345782042,-0.9317166954079665,-0.547187047862277,0.6370802328732516,0.3079413348827095,-0.181289098424304,-0.705389911598884,-0.320804487320434,-0.4453553947676881,0.7409678016646034,0.7769069533107957,-0.9741492520301002,-0.88419649936033 +220,0.863057579368428,-0.7561183593682059,-0.4935969067042674,0.7049325880217174,-0.008851276161342752,-0.7534126622183732,0.5089147005044228,-0.22056744936680878,-0.8718635933587418,-0.47572363456963696,0.9999081175941491,-0.00416068869425817,-0.8585550706487349,-0.6162599997498184,0.4858811595622658,0.22421100816259606,0.01770192144906925,-0.662427934348862,-0.5172313146814341,-0.41613480323670726,0.8761609461511297,0.7558818771860684,-0.9997549668585279,-0.5424091368565301 +221,0.8274919683353151,-0.6428022508736627,-0.41949394667951106,0.7074240339689672,-0.10863618881435774,-0.7607406383035582,0.6056557214181084,-0.2209904780310367,-0.9292309786891924,-0.5841958109553774,0.9866961823862243,-0.0006430798479338025,-0.7615998482805871,-0.5261560357765572,0.3183967713303756,0.1768543253797331,0.2159872215759115,-0.6650862494621198,-0.6894718455163934,-0.32163060222234147,0.9638680284648419,0.7811832547376729,-0.9632007600821836,-0.4406755136322458 +222,0.7882499771456274,-0.6524594173291111,-0.34247951970915363,0.7426881997576302,-0.2073356445790858,-0.7164056738000313,0.694162441204496,-0.18788453491502732,-0.9701031612511273,-0.31296228826495937,0.9513251781126607,-0.004941384067946012,-0.6435380004337727,-0.6056826307339235,0.14024065486026568,0.05604795725440266,0.4056617927656072,-0.595677249911845,-0.8294719260369872,-0.19638246682072358,0.9993355110484666,0.7860741887340983,-0.8667593952309963,-0.4161663007073992 +223,0.7455059500627972,-0.6516597893488985,-0.263088138770635,0.7023011970249347,-0.30396347111923505,-0.7864535089475843,0.7732315508602453,-0.13765608587046835,-0.9937546002969656,-0.30571427896975817,0.894589461274079,0.0664602264114635,-0.5076414384820207,-0.6137413239041724,-0.042615918057195085,0.14559824794374857,0.579163908417558,-0.5815621917160112,-0.9306850008336166,-0.1333939487935336,0.9806411316846843,0.79727897724113,-0.7164271292652535,-0.20004062452708823 +224,0.6994497901904098,-0.689956408587904,-0.18187081391593907,0.7305068987120813,-0.3975541951329225,-0.787380992812912,0.8417880521053933,-0.15467421036207343,-0.9997654483487667,-0.31999792050200804,0.8177631937557307,-0.08120892401723825,-0.35767633680604444,-0.5773918304768862,-0.22404413008322513,0.1986158089782926,0.7295765866982961,-0.5396574159615726,-0.9883782368974129,-0.10474759650992872,0.9087980860251031,0.8314484557294043,-0.5215508943602077,-0.0717627762352843 +225,0.6502861157709641,-0.6254422537177807,-0.09939122802678423,0.712956764281032,-0.48717268904279,-0.6797946260772141,0.8988998727155338,-0.09326410526727993,-0.9880290041022272,-0.23891306832633485,0.7225717278984704,-0.12418153302949728,-0.1977987587809895,-0.6016771627555849,-0.39796303835712543,0.3339216262154872,0.850903348779433,-0.5109351140903798,-0.99985383601822,0.06425410865318756,0.7877001145381402,0.8289274845704195,-0.29424715093454945,-0.4435025071799735 +226,0.5982333511082314,-0.7373724287990674,-0.016221824614936854,0.7328470002641662,-0.5719235144816223,-0.6762830126708941,0.9437905386771288,-0.19727672917108377,-0.9587536065301835,-0.19389724097334332,0.6111528588048277,-0.17185553080851623,-0.0324394775607966,-0.5248905189298503,-0.5585433899099405,0.3507777411481352,0.9383072795255256,-0.4724151916163982,-0.9645751867510747,-0.06879293422794876,0.6239104700465702,0.8206129832187511,-0.0486485448266961,-0.6806242610931467 +227,0.5435227561520647,-0.6013083980094209,0.06706016517995395,0.687720624206052,-0.6509598692165303,-0.6564555458405403,0.9758497308794016,-0.07976810109513131,-0.9124589365625174,-0.3332398944880627,0.4860088140698168,-0.06670864114422564,0.13381881631696257,-0.5553985022324505,-0.7004030011692479,0.39465020524220307,0.9883038600367794,-0.4973219894410031,-0.8841919569486304,0.018371214775720186,0.42630620336595004,0.8110360175178235,0.19997479217308345,-0.6507800433380467 +228,0.4863973990570174,-0.5171949835281928,0.14987672882070088,0.6998886990897788,-0.7234920481169532,-0.725095324414154,0.9946415828174126,-0.15293072675068903,-0.8499667919930296,-0.16425718396405078,0.3499500591436264,-0.12937105565316537,0.2963685174942595,-0.6544841353948084,-0.818787153005781,0.3455565286177644,0.9988998844215471,-0.504544369122226,-0.7624629535054732,-0.18011432551861453,0.20559704600671688,0.8124500586500436,0.43616466515770647,-0.7906963256499492 +229,0.4271110762794862,-0.3797482264289919,0.23165308402618737,0.7280462462183894,-0.7887953336272641,-0.6744983584163717,0.9999106064934012,-0.19532215423323127,-0.7723864993721168,-0.018623076260781978,0.20603218033663998,-0.1472580348745244,0.4507047988336424,-0.58253074529561,-0.909727955007746,0.4825282274031021,0.9696729226252964,-0.5605798382489048,-0.6050803564763061,-0.14356497388284897,-0.026255034524455524,0.8455481230835835,0.6452359317920483,-0.9999683158444924 +230,0.36592718501117627,-0.36552940143564244,0.31182166800558914,0.7127955994182923,-0.8462172369050445,-0.6523157217315282,0.9915851659497428,-0.2726759691081855,-0.6810952218454073,-0.05733362147973213,0.057487262938318864,-0.16291873853548708,0.5925504561111989,-0.6298218539902996,-0.97017733755849,0.5999052895687899,0.90178816138383,-0.5387139937810904,-0.41940354657095985,-0.10797067119060126,-0.25668414795437966,0.8551827991579193,0.8141895533370094,-0.9515785817192471 +231,0.303117552958418,-0.3830839056139636,0.38982607658625396,0.6842281332811829,-0.8951840172738977,-0.69883633439414,0.9697784512084735,-0.3109255932893988,-0.5777135125028576,0.10758221137653873,-0.09234869445041148,-0.19419068926018188,0.7179744444735174,-0.6628362894215206,-0.9981092142123112,0.6133945187120781,0.7979519519050938,-0.5323151113761191,-0.21411497254091974,-0.18416720623297192,-0.4732015258041208,0.8445318108111722,0.9325208119184917,-0.7236956117534291 +232,0.23896123066642116,-0.3365524026297233,0.4651249258960422,0.6903737476839749,-0.9352064148506432,-0.8308211410018617,0.9347869393760198,-0.4341954644629979,-0.4640765472019998,0.18259978538013974,-0.24011069925375803,-0.1918680774591743,0.8235008215942224,-0.6254658140923302,-0.9925873901843565,0.7171535811021404,0.662303916186903,-0.5125540458253387,0.0011858495289246477,-0.14734713233546703,-0.6640723861124043,0.8346375519035834,0.9928724430060355,-0.4247860057294489 +233,0.1737432517539016,-0.2962049967379527,0.5371956097983421,0.7352132867863307,-0.9658845390683422,-0.7189119774551187,0.8870863638354546,-0.470769940848585,-0.34220154752457754,0.28587731088907276,-0.38248033539982557,-0.28143107396596456,0.9062050783297869,-0.5832059343079099,-0.9537969408607838,0.6947926755347933,0.5002519133512023,-0.5192625483857103,0.21643121999031506,-0.24365802502267817,-0.8189519336913215,0.8264420249477866,0.9914920744869923,-0.6154877951520632 +234,0.10775336656612035,-0.3991013939353041,0.6055379270015925,0.7848348993316667,-0.9869118642506486,-0.7730107150305804,0.8273252463278532,-0.5104734383907741,-0.21425197215653333,0.3383707119395907,-0.5162602877914844,-0.2972123096784259,0.9637951872115138,-0.611423250678349,-0.883038008613297,0.6173086440017623,0.31825644536639547,-0.5121687797050326,0.42155603217227483,-0.045307046227585746,-0.9294460260514841,0.8179607559704594,0.92846553099017,-0.6661919614463637 +235,0.041284754872455746,-0.45131003350147136,0.6696775526696496,0.8482197527386605,-0.9980782923150526,-0.7033280444458287,0.7563160798584979,-0.5157009412554323,-0.08249911234648516,0.26552518343086273,-0.6384461471104893,-0.30435254289359076,0.9946751226333268,-0.6538468636950066,-0.782682225830674,0.6033091037170336,0.12357309722111383,-0.49142415725776123,0.6069684276712939,-0.16137571125399267,-0.9895661180944233,0.8144690947851516,0.8077114977260982,-0.5260764971944196 +236,-0.025367276672264887,-0.5189064371198691,0.7291693304384873,0.7806298817272771,-0.9992722520034635,-0.5652832937191195,0.6750242823023124,-0.4914954877417831,0.05071822682078113,0.08216835286013115,-0.746293882373619,-0.33243567193862816,0.9979890923662418,-0.5925387634951199,-0.6560932247398279,0.5242588620192876,-0.07603672035291548,-0.49342391427599835,0.763998322039953,-0.0645436945508966,-0.9960538285362807,0.8357543330133976,0.6367378756204817,-0.3187745172970334 +237,-0.09190660651575888,-0.4579033211533678,0.783600361991259,0.7625762682542608,-0.9904818136653462,-0.5736285463329884,0.5845550708917716,-0.3551619167810877,0.18303524361971082,0.06397015692108426,-0.8373814659527854,-0.4061753464083979,0.9736452545878657,-0.5210274960826689,-0.5075138982945424,0.40482378122727214,-0.2726151938342136,-0.5259434181265721,0.8853028273268585,-0.2314334844943017,-0.9485575372520637,0.8662515229872183,0.4261749763985662,-0.3428855938767574 +238,-0.15803761381669842,-0.3219688785092033,0.8325928727487738,0.7851350745387872,-0.9717948084548418,-0.6060466180946368,0.48613843603752677,-0.4417404099245068,0.3121031183575844,0.24396524200812814,-0.9096632670939825,-0.4263330434970826,0.9223182631441286,-0.5386592820924623,-0.3419241908171777,0.3183672924482967,-0.4583253598243403,-0.5213229119836962,0.9652096134709015,-0.0915675302821888,-0.849651442339514,0.8915688665221083,0.1891145812891417,-0.31598075884033555 +239,-0.22346649182858055,-0.1627533830810707,0.8758068337862263,0.8257111622342458,-0.9433979507509122,-0.6970602196414096,0.3811124187729042,-0.4067840914645202,0.4356307083438057,0.2486386847079032,-0.9615159923738889,-0.5525935122991656,0.8454305705050417,-0.5903943293187606,-0.1648741848348143,0.22891349552440327,-0.6257635400475741,-0.5975410196754017,0.9999821525990276,-0.20950524924411068,-0.7046960440502352,0.884593384480506,-0.05970404252325249,-0.19391759053313595 +240,-0.2879025532239214,-0.3052508250905807,0.912942321778985,0.8650154178246391,-0.90557497256898,-0.7721679076614789,0.27090491917673515,-0.5326554254581114,0.5514252190083923,0.20169399880829042,-0.9917751413705649,-0.5475973367613293,0.7451130065757331,-0.6069289146036069,0.01770192144906925,0.30175593482794716,-0.7682545026401588,-0.5668459928404246,0.9879944431168955,-0.3068394916105448,-0.5215476170626268,0.875103699386991,-0.3048105581433814,-0.12340621695678702 +241,-0.3510595215595087,-0.2789034295324414,0.9437416005984248,0.7763973725072962,-0.8587037886044255,-0.7604532506839113,0.1570142831014063,-0.4963991985749585,0.6574311291981556,0.09596623460155,-0.9997611588323999,-0.4618589746030129,0.6241457258632798,-0.5809037173309524,0.1996847111166777,0.31146969363006377,-0.880117582579756,-0.5262859660692,0.9298070433046642,-0.37243895437508484,-0.3101324170823232,0.8561168652929093,-0.5309654295841015,-0.27775828033812267 +242,-0.4126568031449828,-0.26181293074978707,0.9679909101106019,0.7034076268981816,-0.8032527202338713,-0.8155123531172739,0.04098893114274343,-0.4618254379840683,0.751766679669869,0.0357472399928521,-0.9852946960240333,-0.49433990710420195,0.4858811595622658,-0.6681843803720179,0.3749746534873206,0.2075184015933644,-0.9568931518735341,-0.5282427110462286,0.8281408590046396,-0.3931843939017494,-0.08190869886793242,0.8385216309064362,-0.7241074423025479,-0.16242887881936732 +243,-0.4724207336640902,-0.43594259063285806,0.9855219497627712,0.705212497702951,-0.7397758162027859,-0.7411786657616208,-0.07559369319027275,-0.48187303552700145,0.8327572770565465,-0.05973790673721506,-0.9487006385122951,-0.43315320796307955,0.3341511078292741,-0.608209359604295,0.5376965425403477,0.2905353637686851,-0.9955204108530742,-0.5000439486364324,0.6877499111168577,-0.3202996316155427,0.15075429683768093,0.8672525634436127,-0.8722279614160239,-0.15923082330691618 +244,-0.5300857940101242,-0.5394118172926926,0.9962130466610261,0.67248088698902,-0.6689073167535775,-0.800335611564253,-0.1911485698920146,-0.5087081356649299,0.898965220337801,0.0923541308209729,-0.8908008099367958,-0.45056345706562506,0.1731605470479181,-0.593340188805178,0.6823964166318557,0.2510914025773328,-0.994459412598224,-0.5063606755788238,0.5151990324323873,-0.23600073964033919,0.3752467298813699,0.88711968815898,-0.966117570456129,-0.3058475081250469 +245,-0.58539578993384,-0.5200402991308077,0.9999900000320824,0.7056622641215812,-0.5913553165067262,-0.6752994154180928,-0.3041046518843542,-0.4283581552124006,0.9492152221261512,0.14451936759380193,-0.8128955156227154,-0.41684578766946356,0.0073710950607947,-0.6402174749360955,0.8042243593322289,0.29093615667373046,-0.9537524557612311,-0.5300386094340512,0.3185568889411502,-0.15969501631480026,0.5794015872506778,0.8800112465688289,-0.9999386682796811,-0.25613951722963174 +246,-0.6381049902629059,-0.5425003954167076,0.9968265962083009,0.6573902254521305,-0.507894689413418,-0.7288790083444402,-0.41292622446932964,-0.3945390611433886,0.9826152717292859,0.16059914194296263,-0.716734340526499,-0.39646307929478114,-0.1586226360410919,-0.6237964166753928,0.8990970544354528,0.33914495419151464,-0.8750223982529546,-0.5281338713299394,0.10701868027047284,-0.2314834694412968,0.7521541093343553,0.8804184602283035,-0.9715884228338368,-0.025832435943087755 +247,-0.687979218636011,-0.3261917479854887,0.9867447905617235,0.6544056697919666,-0.4193593464711572,-0.7279902755547275,-0.5161337844141267,-0.3102060168074406,0.9985724696385606,0.16922524506963466,-0.6044768573296145,-0.4138162773105634,-0.320220372869979,-0.5935961768060003,0.9638346467717017,0.29134733841846805,-0.7614079590191913,-0.4890725941482864,-0.10952383780951712,-0.18294920617560467,0.8841414763899657,0.8998783937217595,-0.8828295150682364,-0.09918587126246016 +248,-0.7347968939012778,-0.24513302206219612,0.9698145551244235,0.6601667137384574,-0.3266339035607231,-0.6284145324680801,-0.612324154802974,-0.36158208001358433,0.9968035523584958,0.3091255662281644,-0.4786441270918835,-0.3308461290236487,-0.47294367058748643,-0.6323355839838152,0.9962673216493237,0.2907214102774441,-0.6174385872210872,-0.4953487843999407,-0.32094490260015385,-0.050381479910089996,0.9682102534857631,0.9421425439146928,-0.7391805439709614,-0.35973231275237416 +249,-0.7783500145566823,-0.10947987032000399,0.9461533929527567,0.6483237542097788,-0.23064484265672702,-0.5827842858161874,-0.7001895621808903,-0.28194283331926545,0.9773399207462355,0.3137624789809425,-0.34206208165518653,-0.3271154817301286,-0.6125600264507652,-0.6106131987976866,0.9953080306950366,0.3041623192482656,-0.4488538873906183,-0.5183617241419209,-0.5173582361060138,-0.2968820183502662,0.9998040914789326,0.96124968779195,-0.5495729068123556,-0.10386204689460063 +250,-0.8184450828588514,-0.10056921940528357,0.9159255226060192,0.6905708190159823,-0.13235125472595005,-0.5575175674264125,-0.7785354166225499,-0.29082311511724174,0.9405270826006147,0.2344281836287177,-0.1977980593066073,-0.24822314503410212,-0.7352001774333713,-0.5896481693452128,0.9609889265350202,0.49724934932217446,-0.2623747995102293,-0.49790031430240317,-0.6895793381025297,-0.08991768160259525,0.9772106714607033,0.9676258108599989,-0.32579548812127446,-0.19004382727796326 +251,-0.8549039644945556,0.029903752234770854,0.8793407383995913,0.624803998276385,-0.03273525680745955,-0.6213224226340187,-0.8462965529941412,-0.28392207340715236,0.8870185193956938,0.23752903200875597,-0.04909191897331618,-0.17159395808517763,-0.8374653311888003,-0.5125510672288985,0.8944602851329572,0.5774774802460272,-0.0654356563449261,-0.5175145531581934,-0.8295549633149252,-0.21456356122982526,0.9016545088039585,0.959828736069216,-0.08176168394361567,-0.1353503786906766 +252,-0.887564679995588,-0.03742657732626939,0.8366529543428967,0.6639315241055529,0.06720782097621032,-0.6276435746837874,-0.9025517125972311,-0.3747998492125412,0.817764086033171,0.27889789471812354,0.10071672002421789,-0.11478165177302871,-0.9165213586075023,-0.47830778409012353,0.7979519519050938,0.5758572662563993,0.1341121999446879,-0.4885657367957037,-0.9307396998902961,-0.29240568258241917,0.7772305870311007,0.9581699467097061,0.1673556657353764,-0.22484968900942529 +253,-0.9162821243809185,-0.06674026482173183,0.7881584418678484,0.6929049393022351,0.1664793804286414,-0.6182419084917754,-0.9465360683068494,-0.4278108491055114,0.7339931495349514,0.24746030579661482,0.2482634786223801,-0.12346740607677235,-0.97017733755849,-0.4242362048664603,0.6746986038257905,0.4318397832986428,0.32831342603803426,-0.4494047646417985,-0.9884020399456962,-0.4049559741987855,0.6106824184055323,0.9767617694783794,0.4060676506929371,-0.06554146835331684 +254,-0.9409287118289015,0.024203716694722044,0.7341937735787256,0.6688776649083711,0.26408753294259263,-0.6758784077421804,-0.9776516229160086,-0.43581833405345427,0.6371927659886355,0.3238707810400973,0.39023477471497986,-0.05654342897675449,-0.9969462710905873,-0.5045453516012357,0.5288313325190398,0.42858078140254363,0.5094258318881653,-0.42603276926312034,-0.9998456300017843,-0.3829520253336366,0.4110365598913215,0.981823030112944,0.6195323158871154,-0.12494358957707953 +255,-0.9613949425153832,0.03733183140617771,0.6751334872947263,0.6091865375478093,0.3590570101231171,-0.6707571678594564,-0.9954753393141288,-0.49847677675884433,0.5290812831369119,0.2156773111875082,0.5234422390637385,-0.02143720611927902,-0.9960862973772684,-0.5226854594774414,0.3652391821536459,0.554693948530491,0.6702290374070955,-0.4579124217734183,-0.9645353553921383,-0.389785108296746,0.1891133929396125,0.9999960871536114,0.7944774623354411,-0.08409307459090438 +256,-0.9775898890993684,-0.06000329575234665,0.6113874865977117,0.6385979371100539,0.4504389083469587,-0.7462020729222283,-0.9997648919640457,-0.40561454875858766,0.4115778372033273,0.23932286560177518,0.644894319197637,-0.054056106661375836,-0.96762124933175,-0.49393213968607574,0.1894052829865774,0.40476598278404824,0.8043123262346372,-0.466751971104443,-0.8841223628046778,-0.4014350962497896,-0.0430593182017702,0.97675388660634,0.9200258481648044,0.029509716148383235 +257,-0.9894416006948833,-0.11168092375673382,0.5433981959264884,0.6998082964657955,0.5373201698939015,-0.8726498666012791,-0.990461961482367,-0.28242073905495796,0.28676828543018756,0.012511288983343836,0.7518634632320171,-0.17539998874430404,-0.912339994112663,-0.47654654950313363,0.007223072878819541,0.4559844714805124,0.906330220769519,-0.4146347197374258,-0.7623668508740649,-0.24678787983595638,-0.27289830576162316,0.9675711169799527,0.9883714828280601,0.5777347072365625 +258,-0.9968974225342699,-0.16527474647399007,0.4716374899624757,0.7663996026916836,0.6188327059177012,-0.8854003239616809,-0.9676930275311245,-0.23833319181188234,0.15686817907564163,-0.02118488768340095,0.841947374803662,-0.23087346864567956,-0.8317745708249755,-0.5175304429510847,-0.17520123349920907,0.3867991760886779,0.9722155894927249,-0.4143279074946784,-0.6049622392205793,-0.367081298355812,-0.487946784829653,0.9691627898331531,0.9952649657885488,0.34942009150093045 +259,-0.9999242299027156,-0.12547860699896732,0.396603418618098,0.8008819730930491,0.6941620701028643,-0.8492802070179101,-0.9317676492408153,-0.31885001008882896,0.024183434153157633,-0.0075090062579758485,0.9131229634680543,-0.25340890029315055,-0.7281577322977505,-0.5185183996455054,-0.3517533071184418,0.48448094554867255,0.9993417906665604,-0.4560116198679185,-0.41926893798173404,-0.15005660041100968,-0.6765495846383683,0.9426856151112875,0.9402776936638297,0.007146977163294211 +260,-0.9985085753046877,-0.16169825777964353,0.3188167503579329,0.8172796235081048,0.7625555963430352,-0.7890734239210294,-0.8831742565436466,-0.25585072610220155,-0.1089306019349719,0.0767542202041181,0.9637917789492614,-0.3787087463068442,-0.6043610676062894,-0.5557152549257511,-0.5165156390114066,0.5997523193608882,0.9866273882519041,-0.48569198704522903,-0.21397016706162447,-0.10039973660097501,-0.8284848340612185,0.954165291867221,0.8268285087080932,-0.11386848691735907 +261,-0.9926567482084496,-0.09958054630580518,0.23881735784341807,0.8281372386762353,0.8233299191319133,-0.8140601293384271,-0.8225735096366839,-0.3553674204940323,-0.2401109612432456,-0.05909290003012409,0.99281590888706,-0.4283359818026377,-0.463815420182426,-0.5783250359551799,-0.66396587785995,0.5674152451898657,0.9345792653504099,-0.4093243289851332,0.0013340806372143506,-0.08863242658300413,-0.935517965314699,0.936414561641438,0.6619711317596895,-0.06172120231117038 +262,-0.9823947471032186,-0.1273255444175156,0.15716047098625227,0.819354984791028,0.8758778015255105,-0.8142148964954122,-0.750789316857493,-0.4320719154658526,-0.3670290014043059,-0.1823744687258444,0.9995435338953585,-0.5641075115798541,-0.31041580700853816,-0.5722722715595743,-0.7891619230545719,0.5446055906527385,0.8452724163748147,-0.4236451717870244,0.21657594528042243,-0.11255143342711889,-0.9918480100406779,0.9636941722507111,0.45595559604354075,0.0036747620035962832 +263,-0.9677681639931142,-0.05392315403206523,0.07441282341595735,0.7773031286985801,0.9196742024529744,-0.7754926354004026,-0.6687976330903763,-0.4963820079625184,-0.4874317424213096,-0.16027266766492754,0.9838235660203686,-0.623808455402353,-0.1484134739207933,-0.6159562919832948,-0.8879075694477435,0.544046775450042,0.7222672235697097,-0.46807224582075585,0.4216904841328718,0.006279260229656684,-0.9944219994691983,0.950628144220682,0.22159094975044577,-0.21690892090056274 +264,-0.9488419818410697,0.023291407879657174,-0.008851280893268,0.7244760535981962,0.9542815227533997,-0.7382299520913037,-0.5777131909964762,-0.48739017609552554,-0.5991818603430317,-0.043017967843146786,0.9460090418503241,-0.6039620033566111,0.01770192144906925,-0.5728131075021685,-0.9568931518735341,0.5351136734962498,0.5704675158069115,-0.484391977127757,0.6070863191830737,-0.06446406498724573,-0.9431004288408069,0.9640884448605256,-0.026551148539808382,0.09366727336975204 +265,-0.9257002858626101,0.08781529682131493,-0.09205395353714756,0.6919085312506786,0.9793539775220065,-0.6897363650009739,-0.4787743454656908,-0.5270042123887773,-0.7002956278758357,-0.2353990332670686,0.8869491940748067,-0.5783234028815416,0.18332673396610968,-0.654899739555987,-0.9938064754408799,0.5520117770113765,0.39592506840323277,-0.4841086702635434,0.7640941403681166,0.06859029520308045,-0.8406648182681866,0.9931334946628024,-0.27304242501225534,-0.18911082560742415 +266,-0.8984458899521787,0.061272557365904215,-0.17461773247336926,0.6936677567507612,0.9946410510785969,-0.6773745039948267,-0.37332623734037823,-0.4928642265593052,-0.7889781284312809,-0.3785855536738455,0.8079703795493773,-0.6542706935977146,0.34387091428136973,-0.6237288082532072,-0.9974103135414998,0.4358187377913357,0.2055983379361198,-0.48439188401825495,0.8853720919025989,0.07797059288008035,-0.6926669602544497,0.9871162317527288,-0.5025572459569345,-0.1402098835585151 +267,-0.8671998799017084,0.13591482984016667,-0.255969589854493,0.6724874969460695,0.9999900000374283,-0.6518101888497072,-0.2628025053114708,-0.47788989635411894,-0.8636551185088495,-0.553126165856611,0.7108462921800559,-0.7011778053651275,0.4948852153337387,-0.7108422450340243,-0.9675838760660482,0.4417666358621182,0.0070750505385742705,-0.43412962208192485,0.9652490854076312,0.24917811902456927,-0.5071280243323874,0.9713880193945247,-0.7008254914444759,0.1783590433038968 +268,-0.8321010754408373,0.07765835749459722,-0.3355449090856859,0.6902638128685458,0.9953473794686317,-0.675491955507925,-0.1487057946261785,-0.4614954699990039,-0.9230009728100718,-0.4625403917879892,0.5977581295820428,-0.6361133747992225,0.6321844967224848,-0.6065058618232644,-0.9053268579351827,0.5098535153422181,-0.19173029679733813,-0.43302845289482667,0.9999899861462309,0.2162642092734222,-0.2941038267140396,0.9471383560401281,-0.8555198022671052,0.23383297144148335 +269,-0.7933054134888071,0.10320458127004684,-0.4127914035047448,0.6786506371444788,0.9807595769022848,-0.644089455683698,-0.03258732760469798,-0.491219946568623,-0.9659622160175942,-0.3679979756211201,0.47124560808368704,-0.6090595899444559,0.7519637097500518,-0.6321780918087443,-0.8127259322506333,0.5218052366996114,-0.382891962239873,-0.3828534182943108,0.9879702719694192,0.17236465383993282,-0.06513982641222865,0.9602678598915634,-0.9570220354275899,0.050291624765572224 +270,-0.750985255358137,0.16635332603607386,-0.4871729494871265,0.7098366289663778,0.9563723488395427,-0.6396070718247229,0.08397418628143136,-0.7107307920766205,-0.9917762235158499,-0.4114991307182967,0.33414992617302336,-0.6896339139361498,0.850903348779433,-0.74179455905897,-0.6928848111178505,0.6250623381295031,-0.558788933433369,-0.3355672889241857,0.9297519977315464,0.12409809870736224,0.16735461411352245,0.9492207934098709,-0.999021273686088,-0.08533367342972202 +271,-0.7053286209880268,0.0823976624579209,-0.5581733073726326,0.7954606898351773,0.922429364401863,-0.5513227965402728,0.19939401403745713,-0.6925139971775065,-0.9999847590883144,-0.42484428142401753,0.18954995730501162,-0.6808110816443433,0.9262614464760786,-0.7511123626487388,-0.5498202182967745,0.7013132797247675,-0.7124087532113248,-0.35369749095524844,0.8280575129964988,0.1338696918831837,0.3907787900427832,0.9525448777146422,-0.9789062078276909,0.07003470098115784 +272,-0.6565383536096343,0.007414192132074792,-0.6252997043888671,0.705227568713586,0.8792697706696899,-0.6147200126424507,0.3121029446674387,-0.623982535211518,-0.9904421092775101,-0.3458205521933108,0.04069310504139096,-0.7456773451525354,0.9759495634207083,-0.8119145156497758,-0.3883272603731776,0.7446052175093468,-0.837627084134564,-0.34302536175731047,0.6876421620233208,0.15828192412420555,0.5930235856295334,0.9882406312898194,-0.8979274952217354,0.060342763257559794 +273,-0.6048312185544522,-0.11973689807004091,-0.6880862547044857,0.67125021817605,0.8273248040369273,-0.6420103736001984,0.4205686236641493,-0.6673015974791038,-0.9633176700110243,-0.34412747542815897,-0.1090776266323372,-0.8373357179030392,0.9985906661521659,-0.7721829582251308,-0.21381870881502654,0.6859405393942053,-0.9294518664984789,-0.3954500708896268,0.5150719187177752,0.1370806310253102,0.7631277625538627,0.99417364902291,-0.7611200000040392,-0.09884597800145195 +274,-0.5504369402095979,-0.21905277594349684,-0.7460971928756291,0.6396118409692197,0.7671134814394504,-0.6076241875491211,0.5233163864099641,-0.671909912705966,-0.9190929395771664,-0.3733085197263535,-0.2563987099693063,-0.9595804959186276,0.993557289635034,-0.899919511122672,-0.032143579709033164,0.7080824198534638,-0.9842223360000716,-0.3660846545546305,0.31841635458044437,0.16273823925328307,0.8918720357554514,0.9737542635802916,-0.5769897496109039,-0.26155517018449753 +275,-0.4935971813986275,-0.18963496595498672,-0.7989298982439903,0.6380730131438741,0.6992374145094317,-0.5748635518851734,0.6189493072445122,-0.6052659748130983,-0.8585529713293221,-0.32389501376306135,-0.39796163104318627,-0.9999856221823422,0.9609889265350202,-0.9328257362985448,0.15060890814599096,0.7073754571618106,-0.9997549668585279,-0.3821506507471892,0.10687129680259598,0.26030599154570916,0.9722787391005914,0.9441228326950064,-0.35698507119139966,-0.09300016930637704 +276,-0.4345644697222834,-0.29867115038919834,-0.8462176892959353,0.6779697116692963,0.6243747984707934,-0.4709180837887278,0.7061671916183289,-0.6544577921568425,-0.7827724378459455,-0.35181057192253085,-0.5305871918382075,-0.9945770950585407,0.90178816138383,-0.9999889921636391,0.32831342603803426,0.6807308785985277,-0.9754305220975703,-0.37936798395463384,-0.1096711785739513,0.18759886470719073,0.999990000064903,0.962194857874149,-0.11478479007430807,-0.3461538827325484 +277,-0.3736010756292228,-0.3075947071390139,-0.8876323685886472,0.6434702817899297,0.5432736358367058,-0.42344278500584825,0.7837842531213786,-0.7173360989917558,-0.6930965539276435,-0.08278281573345812,-0.6512969081826581,-0.8668990007318169,0.8175956567694579,-0.9871757506010305,0.4950138384393182,0.7655285698661214,-0.9122187405695967,-0.3111628453481045,-0.32108531084741176,0.23521421966874864,0.9735039271673561,0.953663484404848,0.13455225333850804,-0.219386770103228 +278,-0.3109778472012334,-0.3130974630769137,-0.9228865005803804,0.6463958815424886,0.45674426261572276,-0.33524248565246284,0.8507452350551237,-0.7530071770977961,-0.5911171970759168,-0.21212892328805805,-0.7573799000821544,-0.9390859420931016,0.7107446847730562,-0.8558999831884421,0.6451228352433134,0.7305946104115967,-0.812639676527811,-0.2544242090891233,-0.5174851461943671,0.12650155011337869,0.894256009298271,0.9803115724814706,0.3755234893319841,-0.18195072952612906 diff --git a/python/cuml/test/ts_datasets/exog_hourly_earnings_by_industry_missing_exog.csv b/python/cuml/test/ts_datasets/exog_hourly_earnings_by_industry_missing_exog.csv new file mode 100644 index 0000000000..4d5a4b26e9 --- /dev/null +++ b/python/cuml/test/ts_datasets/exog_hourly_earnings_by_industry_missing_exog.csv @@ -0,0 +1,124 @@ +,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 +0,0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,0.0,-0.0,0.0,0.0,0.0,0.0,0.0,-0.0,0.0,0.0 +1,0.09973465873441716,-0.008193039837578154,0.14935054265784684,-0.008194815669663652,0.19847280172665502,0.00819052427043743,0.24715922057524256,-0.008089077522064248,0.2952268271106974,-0.008194360406357605,0.3432034574627649,-0.00819461596484854,0.3890550169554975,0.008194386472961343,0.4358505061494723,0.008185445103196303,0.4789512779113474,-0.00819276885829072,0.5221701726051636,-0.008183693061075874,0.5641280675148789,0.008194542812409741,0.6054198916439166,0.008169544057826118,0.6435763695987253,-0.008194726758847215,0.6809594564678602,0.00816459311277819 +2,0.19847280172665502,-0.016386079675156307,0.29534699410826415,-0.016389631339327304,0.3890331191656121,0.01638104854087486,0.4789512779113474,-0.016178155044128496,0.564081921015252,-0.01638872081271521,0.644791925790795,-0.01638923192969708,0.7166868012776797,0.016388772945922687,0.7849206506424798,0.016370890206392607,0.8406385789802162,-0.01638553771658144,0.8903257536678935,-0.016367386122151748,0.9311899709804021,0.016389085624819482,0.9639299341140939,0.016339088115652236,0.984468716416999,-0.01638945351769443,0.996500908903801,0.01632918622555638 +3,0.2952278700904258,-0.024579119512734462,0.43471058880134983,-0.02458444700899096,0.5640839138084358,0.02457157281131229,0.6809644645497407,-0.02426723256619274,0.782549256893179,-0.024583081219072815,0.8681964236416051,-0.02458384789454562,0.931169498191638,0.02458315941888403,0.9777085447461252,0.024556335309588912,0.9965082376193338,-0.02457830657487216,0.9958789026408166,-0.02455107918322762,0.9729604270787925,0.024583628437229218,0.9293181153254614,0.024508632173478356,0.862350043156335,-0.024584180276541644,0.777297789267253,0.024493779338334572 +4,0.3890331191656121,-0.032772159350312614,0.5643115208503803,-0.03277926267865461,0.716646462877443,0.03276209708174972,0.8406385789802162,-0.03235631008825699,0.9311137982811112,-0.03277744162543042,0.9863281337304184,-0.03277846385939416,0.998641005901436,0.032777545891845374,0.9758290014817222,0.032741780412785214,0.9083979252494543,-0.03277107543316288,0.8076966173381208,-0.032734772244303496,0.6748478137513148,0.032778171249638964,0.5157002565033916,0.03267817623130447,0.3346546700081299,-0.03277890703538886,0.14097937151035483,0.03265837245111276 +5,0.4789512779113474,-0.040965199187890766,0.6812392327245661,-0.04097407834831826,0.8406385789802162,0.04095262135218715,0.9480458581364921,-0.04044538761032124,0.9965047171593899,-0.04097180203178803,0.9848630458384661,-0.040973079824242704,0.9084490569035975,0.04097193236480672,0.7796562488057206,0.04092722551598152,0.597880119293297,-0.0409638442914536,0.38128344268970554,-0.04091846530537937,0.14099144358699423,0.0409727140620487,-0.10823687710557638,0.04084772028913059,-0.3504340235683619,-0.04097363379423608,-0.5709917155173455,0.04082296556389095 +6,0.5640839138084358,-0.049158239025468925,0.7828677800964029,-0.04916889401798192,0.9311170877274866,0.04914314562262458,0.9965082376193338,-0.04853446513238548,0.972880837495202,-0.04916616243814563,0.8639788085744062,-0.04916769578909124,0.6748329768037858,0.04916631883776806,0.42824941865481525,0.049112670619177824,0.14098040833601205,-0.04915661314974432,-0.15758964784163237,-0.04910215836645524,-0.44211729417156453,0.049167256874458436,-0.6880315050419809,0.04901726434695671,-0.8707081201779896,-0.04916836055308329,-0.9765559364367731,0.048987558676669145 +7,0.6435804097001016,-0.05735127886304707,0.8669148048900676,-0.05736370968764557,0.9844748964969458,0.05733366989306202,0.9830125613962186,-0.05662354265444973,0.8623524101011013,-0.057360522844503235,0.6383331886679064,-0.057362311753939785,0.3346756078980445,0.057360705310729405,-0.008424352581826635,0.05729811572237413,-0.35043622344556025,-0.05734938200803504,-0.6499815211051329,-0.05728585142753111,-0.8707817412815769,0.05736179968686818,-0.9872245910049116,0.057186808404782824,-0.9814745826782354,-0.05736308731193051,-0.8580785015295292,0.05715215178944733 +8,0.716646462877443,-0.06554431870062523,0.9314927921233287,-0.06555852535730922,0.9985847978332577,0.06552419416349944,0.9083979252494543,-0.06471262017651398,0.6747926102147798,-0.06555488325086084,0.3352867497045722,-0.06555692771878832,-0.05831968064560095,0.06555509178369075,-0.4434207863978218,0.06548356082557043,-0.7560538458370958,-0.06554215086632575,-0.9506607234136952,-0.06546954448860699,-0.9952570736139368,0.06555634249927793,-0.8837954998144751,0.06535635246260894,-0.630638212981515,-0.06555781407077772,-0.27913703994139577,0.06531674490222553 +9,0.7825520214895382,-0.07373735853820339,0.9751514594248076,-0.07375334102697287,0.9728842744965123,0.07371471843393687,0.7773035058645293,-0.07280169769857822,0.42695559615988093,-0.07374924365721844,-0.008414741364378654,-0.07375154368363686,-0.44210757394878014,0.07374947825665208,-0.7901295718879009,0.07366900592876674,-0.9765631184680146,-0.07373491972461647,-0.9709416366328761,-0.07365323754968287,-0.7720604764071538,0.07375088531168766,-0.4199259662608264,0.07352589652043506,0.016797162274309777,-0.07375254082962494,0.4495955714985163,0.07348133801500371 +10,0.8406385789802162,-0.08193039837578153,0.9969103272494149,-0.08194815669663652,0.9083979252494543,0.0819052427043743,0.597880119293297,-0.08089077522064247,0.14097991028104204,-0.08194360406357606,-0.351095906551303,-0.08194615964848541,-0.756096402389334,0.08194386472961344,-0.9795189805814214,0.08185445103196304,-0.9579756808685278,-0.0819276885829072,-0.7048423860225389,-0.08183693061075874,-0.2791609424719786,0.0819454281240974,0.21520298315387307,0.08169544057826118,0.656332569649626,-0.08194726758847216,0.9370651902622528,0.0816459311277819 +11,0.8903257536678935,-0.09012343821335969,0.9962807383351455,-0.09014297236630017,0.8076966173381208,0.09009576697481174,0.38128344268970554,-0.08897985274270673,-0.15758909110961367,-0.09013796446993366,-0.651205087049036,-0.09014077561333396,-0.9507142339446707,0.09013825120257478,-0.973880483640077,0.09003989613515935,-0.7048423860225389,-0.09012045744119791,-0.23084919990941727,-0.09002062367183461,0.31125754017940216,0.09013997093650715,0.7625651828372738,0.0898649846360873,0.9871845140394278,-0.09014199434731937,0.9216847667495226,0.0898105242405601 +12,0.9311170877274866,-0.09831647805093785,0.973276831891701,-0.09833778803596384,0.6747949941275839,0.09828629124524917,0.14098040833601205,-0.09706893026477097,-0.4420811283313496,-0.09833232487629126,-0.87235267193123,-0.09833539157818248,-0.9952351922702056,0.09833263767553611,-0.7743367384816401,0.09822534123835565,-0.27913909284063765,-0.09831322629948865,0.3112331783813484,-0.09820431673291048,0.7929448091104607,0.09833451374891687,0.998928591634831,0.09803452869391342,0.8537481562929551,-0.09833672110616658,0.4117077786201266,0.09797511735333829 +13,0.9626050076328949,-0.106509517888516,0.9284152260643215,-0.10653260370562749,0.5149914239399227,0.10647681551568658,-0.10808810498050356,-0.10515800778683522,-0.687083374987902,-0.10652668528264887,-0.9877235049343619,-0.10653000754303102,-0.8826303969734252,0.10652702414849748,-0.4206180611819545,0.10641078634155196,0.21490718558485614,-0.10650599515777937,0.7615170332066578,-0.10638800979398635,0.9976336430538367,0.10652905656132662,0.8278965525785565,0.10620407275173954,0.31878070045897905,-0.10653144786501381,-0.31920076905929495,0.10613971046611648 +14,0.9844748964969458,-0.11470255772609414,0.8627034158040058,-0.11472741937529114,0.33465677082919276,0.11466733978612403,-0.35043622344556025,-0.11324708530889946,-0.8707105100636796,-0.11472104568900647,-0.9833283448153961,-0.11472462350787957,-0.6306776692768031,0.11472141062145881,0.016848109704656875,0.11459623144474826,0.6563366898290532,-0.11469876401607008,0.9871907111680082,-0.11457170285506223,0.8538203433783806,0.11472359937373636,0.3192214731299795,0.11437361680956565,-0.3661142998867028,-0.11472617462386102,-0.878819077933425,0.11430430357889466 +15,0.9965082376193338,-0.12289559756367231,0.7776171467029437,-0.12292223504495478,0.14098040833601205,0.12285786405656146,-0.5709959148470373,-0.12133616283096371,-0.9765596684701005,-0.1229154060953641,-0.8597001248435612,-0.12291923947272812,-0.27915480494560424,0.12291579709442016,0.4509597243093111,0.12278167654794456,0.937072081860799,-0.12289153287436079,0.921691545233522,-0.12275539591613811,0.4117430331953406,0.1229181421861461,-0.3196424667630778,0.12254316086739177,-0.8788200243019274,-0.12292090138270824,-0.966843505096389,0.12246889669167285 +16,0.9985847978332577,-0.13108863740125046,0.6750672729300327,-0.13111705071461843,-0.05831639814812038,0.1310483883269989,-0.7560538458370958,-0.12942524035302796,-0.9951756601321037,-0.13110976650172168,-0.631829332203657,-0.13111385543757664,0.11644046502904361,0.1311101835673815,0.7952826443595055,0.13096712165114086,0.9883795467216366,-0.1310843017326515,0.5843385770049826,-0.13093908897721399,-0.17416796440427548,0.13111268499855586,-0.828145851366662,0.1307127049252179,-0.9782029593462851,-0.13111562814155545,-0.5360381833105293,0.13063348980445105 +17,0.9906838288354797,-0.1392816772388286,0.557356843565802,-0.1393118663842821,-0.25528831386612294,0.13923891259743632,-0.8941040105802966,-0.13751431787509222,-0.9248955739576221,-0.13930412690807928,-0.32734634285383213,-0.1393084714024252,0.49365234586922163,0.13930457004034283,0.9812601810204711,0.13915256675433718,0.7976972276032396,-0.13927707059094224,0.07463438693068827,-0.13912278203828984,-0.6992370811936799,0.13930722781096558,-0.9989045234541498,0.13888224898304402,-0.6175217577672163,-0.13931035490040267,0.1824171610570709,0.13879808291722923 +18,0.9728842744965123,-0.14747471707640677,0.4271293810851432,-0.14750668205394574,-0.44208269011914386,0.14742943686787374,-0.9765631184680146,-0.14560339539715644,-0.7719973207328699,-0.14749848731443688,0.016828887949111148,-0.14750308736727372,0.7929273757272797,0.14749895651330416,0.9718631289483876,0.14733801185753348,0.41171080650416075,-0.14746983944923295,-0.45708328691040556,-0.14730647509936573,-0.9800425683449941,0.14750177062337533,-0.762277563471767,0.14705179304087013,0.033589575533404636,-0.1475050816592499,0.802983395784573,0.14696267603000743 +19,0.9453639820795602,-0.1556677569139849,0.2873095135416216,-0.15570149772360942,-0.6112526245897155,0.15561996113831117,-0.9983042615559901,-0.1536924729192207,-0.5501388460490069,-0.1556928477207945,0.3589635391077552,-0.15569770333212227,0.9670166078497127,0.15569334298626555,0.7689624956694915,0.15552345696072978,-0.07507677894349374,-0.15566260830752368,-0.85398380836005,-0.1554901681604416,-0.9184909903716716,0.15569631343578508,-0.21476911310049615,0.15522133709869623,0.6689032066090711,-0.1556998084180971,0.9926508641150769,0.15512726914278563 +20,0.9083979252494543,-0.16386079675156306,0.1410372937265151,-0.16389631339327304,-0.7560538458370958,0.1638104854087486,-0.9579756808685278,-0.16178155044128495,-0.2791381066992712,-0.16388720812715213,0.6575722191407843,-0.16389231929697082,0.9884351803594296,0.1638877294592269,0.4129569731385663,0.16370890206392608,-0.5434829505115776,-0.1638553771658144,-0.9990009892270116,-0.16367386122151747,-0.536084084309812,0.1638908562481948,0.42032914073565747,0.16339088115652237,0.9896212077155743,-0.16389453517694433,0.6496397801173898,0.1632918622555638 +21,0.8623554566266871,-0.17205383658914122,-0.008402319647324475,-0.17209112906293672,-0.8707135861168288,0.17200100967918602,-0.8580848122244682,-0.1698706279633492,0.016797208378505686,-0.1720815685335097,0.8764472596669446,-0.17208693526181934,0.8538015716088414,0.17208211593218822,-0.02527067595158275,0.1718943471671224,-0.878825541163884,-0.17204814602410512,-0.8493618733954124,-0.17185755428259336,0.03359241563741021,0.17208539906060455,0.8840035510959018,0.17156042521434847,0.8449048915755235,-0.17208926193579152,-0.0419824723360154,0.171456455368342 +22,0.8076966173381208,-0.18024687642671938,-0.15765323503621154,-0.18028594473260035,-0.9506607234136952,0.18019153394962348,-0.7048423860225389,-0.17795970548541346,0.3112320788581369,-0.18027592893986732,0.9890490608211937,-0.18028155122666792,0.5843714680951001,0.18027650240514956,-0.45846678700876015,0.1800797922703187,-0.9990009892270116,-0.18024091488239583,-0.44920266111697,-0.18004124734366922,0.591534118342619,0.1802799418730143,0.9871526690389019,0.1797299692721746,0.3028166029243987,-0.18028398869463874,-0.711075995509521,0.1796210484811202 +23,0.744967540438117,-0.1884399162642975,-0.3033635986463942,-0.18848076040226402,-0.9927080176514501,0.1883820582200609,-0.5077762741063465,-0.18604878300747768,0.5778655146606875,-0.18847028934622492,0.981724139138569,-0.18847616719151644,0.2226819589323564,0.1884708888781109,-0.8003795038223758,0.188265237373515,-0.8745861537498308,-0.18843368374068656,0.08344930544220765,-0.1882249404047451,0.9428359349673102,0.18847448468542405,0.6877089419367549,0.1878995133300007,-0.38169106572201283,-0.18847871545348596,-0.9985903093393598,0.18778564159389838 +24,0.6747949941275839,-0.1966329561018757,-0.44226106984409214,-0.19667557607192768,-0.995179175896744,0.19657258249049833,-0.27913909284063765,-0.19413786052954193,0.7928799450672677,-0.19666464975258252,0.855360674879517,-0.19667078315636496,-0.17416413521360993,0.19666527535107223,-0.9829320229899038,0.1964506824767113,-0.5360421255762363,-0.1966264525989773,0.5914878195936001,-0.19640863346582096,0.964778033947753,0.19666902749783374,0.10779522454742542,0.19606905738782685,-0.8866834620719409,-0.19667344221233315,-0.750238832308229,0.19595023470667658 +25,0.597880119293297,-0.20482599593945383,-0.5712263108713863,-0.20487039174159133,-0.9579756808685278,0.20476310676093573,-0.03314639477032304,-0.2022269380516062,0.9370687713766463,-0.20485901015894015,0.6252808161105047,-0.20486539912121352,-0.5435135419313573,0.2048596618240336,-0.9697770799994512,0.2046361275799076,-0.0662562899388779,-0.204819221457268,0.9250664359638998,-0.20459232652689685,0.6496954088319284,0.2048635703102435,-0.5160808782904293,0.20423860144565295,-0.9746547713981599,-0.2048681689711804,-0.09929249585428884,0.20411482781945475 +26,0.5149914239399227,-0.213019035777032,-0.687363040447367,-0.21306520741125498,-0.8825807185111588,0.21295363103137316,0.21490718558485614,-0.21031601557367044,0.9975520351651532,-0.21305337056529774,0.31938279816277887,-0.21306001508606204,-0.8270541111570393,0.21305404829699495,-0.7635339002371887,0.21282157268310392,0.41975139624441277,-0.21301199031555873,0.9857958227932273,-0.2127760195879727,0.10765548555683631,0.21305811312265324,-0.9294824764433802,0.21240814550347908,-0.6042307123349234,-0.21306289573002762,0.6049364043496628,0.21227942093223295 +27,0.4269571045120725,-0.22121207561461012,-0.7880630780016905,-0.22126002308091863,-0.7720000480512634,0.22114415530181059,0.4495988780264069,-0.2184050930957347,0.9689269466129149,-0.22124773097165537,-0.025241845016591723,-0.22125463105091062,-0.980021021501593,0.2212484347699563,-0.4052666960337766,0.22100701778630022,0.8029893012853434,-0.22120475917384944,0.7557637893862074,-0.2209597126490486,-0.4719915960910254,0.22125265593506296,-0.9638110027732304,0.22057768956130516,0.050372492104744704,-0.22125762248887482,0.984542962732691,0.22044401404501116 +28,0.33465677082919276,-0.22940511545218828,-0.8710649177872871,-0.22945483875058229,-0.6306421718625934,0.22933467957224807,0.6563366898290532,-0.22649417061779892,0.8537504996275977,-0.22944209137801294,-0.3668057990237982,-0.22944924701575914,-0.9782641612589208,0.22944282124291762,0.03369145598987282,0.22919246288949652,0.9896274201406913,-0.22939752803214017,0.30281850387948656,-0.22914340571012445,-0.8867584339408113,0.2294471987474727,-0.6050661718987956,0.2287472336191313,0.6812847264156532,-0.22945234924772204,0.8358218491694634,0.22860860715778933 +29,0.23901265731500806,-0.23759815528976647,-0.9345045174239208,-0.2376496544202459,-0.46414258238821615,0.2375252038426855,0.8222666651732259,-0.23458324813986317,0.6623110631940086,-0.23763645178437057,-0.6638928720176782,-0.23764386298060766,-0.8220609000314877,0.23763720771587896,0.4659414438740256,0.23737790799269282,0.9339702320827127,-0.23759029689043087,-0.23944338880493815,-0.23732709877120034,-0.9917550386147392,0.23764174155988246,0.00044424977577437575,0.23691677767695746,0.9917781085257773,-0.23764707600656926,0.238580124064969,0.23677320027056753 +30,0.14098040833601205,-0.24579119512734462,-0.9769571602714157,-0.24584447008990956,-0.27913909284063765,0.24571572811312292,0.937072081860799,-0.24267232566192742,0.4117093520140241,-0.2458308121907282,-0.88047989743077,-0.24583847894545624,-0.5360722981688957,0.24583159418884032,0.8054197900149034,0.24556335309588911,0.6496445578602947,-0.24578306574872158,-0.7110812250819646,-0.24551079183227623,-0.7503030753289238,0.2458362843722922,0.605773491996799,0.24508632173478354,0.8358227492358042,-0.24584180276541648,-0.48668900694371736,0.2449377933833457 +31,0.04153952971884974,-0.2539842349649228,-0.9974694514938842,-0.25403928575957324,-0.08300720853584859,0.25390625238356035,0.9936148951331387,-0.25076140318399165,0.12433087059250325,-0.25402517259708574,-0.990304707696571,-0.2540330949103048,-0.16544966758622034,0.25402598066180165,0.9845343883188172,0.2537487981990854,0.2062632387276392,-0.25397583460701234,-0.9729849743120409,-0.2536944848933521,-0.24674866147569965,0.25403082718470194,0.9640486753625022,0.2532558657926096,0.28676689088951746,-0.25403652952426364,-0.9507899820329124,0.25310238649612393 +32,-0.05831639814812038,-0.2621772748025009,-0.9955807292523456,-0.26223410142923687,0.11643391122887078,0.2620967766539978,0.9883795467216366,-0.25885048070605593,-0.17415371721047815,-0.26221953300344336,-0.9800505421981471,-0.2622277108751533,0.2312938276000737,0.262220367134763,0.9676224842415853,0.2619342433022817,-0.28761851492748053,-0.262168603465303,-0.9479058753109407,-0.26187817795442797,0.3430021588345477,0.2622253699971117,0.929153570940734,0.2614254098504358,-0.3971599170983825,-0.2622312562831109,-0.904675886036727,0.2612669796089021 +33,-0.15758964784163237,-0.27037031464007905,-0.9713334101765015,-0.27042891709890055,0.3112331783813484,0.2702873009244352,0.921691545233522,-0.26693955822812016,-0.45708167212856854,-0.270413893409801,-0.8509607654078644,-0.27042232684000184,0.5915211130983306,0.2704147536077243,0.7580513358944515,0.27011968840547806,-0.7110812250819646,-0.27036137232359375,-0.6432410323016756,-0.27006187101550383,0.8129324568298196,0.2704199128095214,0.5153195330173368,0.26959495390826194,-0.8942962102800252,-0.27042598304195814,-0.3730925694703565,0.2694315727216803 +34,-0.25528831386612294,-0.2785633544776572,-0.9252720367786772,-0.2786237327685642,0.4936245608648517,0.27847782519487263,0.7976972276032396,-0.27502863575018444,-0.6991798825794565,-0.27860825381615856,-0.6186881032575886,-0.2786169428048504,0.8583602212080104,0.27860914008068566,0.39754777343988285,0.27830513350867436,-0.9604464515116701,-0.2785541411818845,-0.14885163195315587,-0.2782455640765797,0.9988820594408003,0.27861445562193116,-0.10867850831880253,0.27776449796608804,-0.9708310220032218,-0.27862070980080533,0.3587005257547418,0.27759616583445845 +35,-0.35043622344556025,-0.28675639431523536,-0.8584310482029581,-0.28681854843822785,0.6563366898290532,0.28666834946531006,0.6241059599443225,-0.28311771327224866,-0.8788224364528556,-0.2868026142225162,-0.31139667851909636,-0.28681155876969894,0.9896831240184422,0.28680352655364705,-0.04210985461304808,0.28649057861187066,-0.9746608898702935,-0.2867469100401752,0.38944169952439245,-0.2864292571376556,0.8358934206718184,0.2868089984343409,-0.6883539324634889,0.2859340420239141,-0.5907688344251345,-0.2868154365596526,0.8980069333782202,0.2857607589472367 +36,-0.44208269011914386,-0.29494943415281355,-0.772311548952134,-0.2950133641078915,0.7928827461599658,0.2948588737357475,0.41171080650416075,-0.2912067907943129,-0.9799623994319073,-0.29499697462887375,0.03365301791329512,-0.29500617473454743,0.964756822704603,0.2949979130266083,-0.4733831665735062,0.29467602371506696,-0.750244349901775,-0.2949396788984659,0.8128688294673042,-0.29461295019873146,0.38090316125745755,0.29500354124675066,-0.9872963182846276,0.29410358608174025,0.0671411670007659,-0.2950101633184998,0.9554228288939755,0.29392535206001486 +37,-0.5293120126861656,-0.3031424739903917,-0.6688475973167545,-0.30320817977755515,0.8978190694876738,0.3030493980061849,0.1737174692044557,-0.29929586831637717,-0.9935652398439103,-0.3031913350352314,0.3746221319845234,-0.303200790699396,0.7875166321642914,0.3031922994995697,-0.8104031466742486,0.30286146881826326,-0.34214182739086046,-0.30313244775675663,0.9965395211529737,-0.3027966432598073,-0.20714753103688738,0.3031980840591604,-0.8835872742435591,0.3022731301395664,0.6934736284765324,-0.30320489007734697,0.5001375645610977,0.30208994517279303 +38,-0.6112526245897155,-0.3113355138279698,-0.5503627705955039,-0.31140299544721883,0.9669621797468185,0.31123992227662234,-0.07507677894349374,-0.3073849458384414,-0.9184158564675367,-0.311385695441589,0.6701665989166474,-0.31139540666424453,0.48594488132549146,0.3113866859725311,-0.9860671637471322,0.31104691392145956,0.14972894707872506,-0.31132521661504736,0.886279928501392,-0.3109803363208832,-0.7228356310681627,0.31139262687157016,-0.41952270897421107,0.31044267419739247,0.9936546066554385,-0.3113996168361942,-0.2235326511039411,0.31025453828557126 +39,-0.6870858023216195,-0.319528553665548,-0.4195179825584126,-0.31959781111688246,0.9975555593250707,0.31943044654705977,-0.3192031166052062,-0.3154740233605057,-0.7612271199060563,-0.31958005584794663,0.8844503001835399,-0.3195900226290931,0.10765311868424986,0.3195810724454924,-0.9653994939682055,0.31923235902465585,0.6049408533238516,-0.3195179854733381,0.5146112237601901,-0.3191640293819591,-0.9860164490556457,0.3195871696839799,0.21563681076800184,0.31861221825521857,0.8265042970427732,-0.3195943435950414,-0.827250269846316,0.31841913139834943 +40,-0.7560538458370958,-0.32772159350312613,-0.2792517250601661,-0.3277926267865461,0.9883795467216366,0.3276209708174972,-0.5434829505115776,-0.3235631008825699,-0.5360402318489593,-0.32777441625430426,0.9914903568075338,-0.32778463859394164,-0.28763470431984006,0.3277754589184538,-0.7525151901656907,0.32741780412785215,0.9120421406254627,-0.3277107543316288,-0.00884255333616546,-0.32734772244303495,-0.904753353516636,0.3277817124963896,0.7628526518205978,0.32678176231304473,0.2706361020450086,-0.32778907035388866,-0.987046977334887,0.3265837245111276 +41,-0.8174676492431415,-0.33591463334070426,-0.13271407584806688,-0.33598744245620976,0.9397999606024338,0.3358114950879346,-0.7339716468720321,-0.3316521784046341,-0.26297046634326493,-0.3359687766606618,0.9783076722890778,-0.3359792545587902,-0.6375113320253639,0.33596984539141506,-0.38980075095393485,0.33560324923104845,0.9958437033202947,-0.33590352318991956,-0.5296882108735904,-0.3355314155041108,-0.5074338816768452,0.3359762553087993,0.9989524628211233,0.3349513063708709,-0.4125164805502681,-0.33598379711273585,-0.6171723028966587,0.33474831762390583 +42,-0.8707135861168288,-0.34410767317828245,0.016804045393007727,-0.34418225812587344,0.8537535157642845,0.34400201935837205,-0.878825541163884,-0.3397412559266984,0.03358966772876154,-0.3441631370670194,0.8465007074276576,-0.3441738705236387,-0.8867389380069325,0.34416423186437645,0.05052527678295439,0.3437886943342448,0.8358279961789783,-0.34409629204821024,-0.8943018242949589,-0.3437151085651867,0.06714684399710508,0.3441707981212091,0.8276470905245503,0.34312085042869694,-0.9016561165949415,-0.34417852387158304,0.08389076892147215,0.342912910736684 +43,-0.9152596406576812,-0.3523007130158606,0.1659447840019286,-0.35237707379553707,0.7336706124276289,0.3521925436288095,-0.9690383198285016,-0.34783033344876263,0.3271493368210481,-0.352357497473377,0.6120516596380092,-0.35236848648848723,-0.9959699632980901,0.35235861833733784,0.4807914291034945,0.3519741394374411,0.4711724450525941,-0.352289060906501,-0.9951402597944171,-0.3518988016262626,0.6182712452360382,0.3523653409336188,0.3188004165445946,0.35129039448652305,-0.9667327922395572,-0.3523732506304303,0.7399361865388727,0.3510775038494622 +44,-0.9506607234136952,-0.36049375285343876,0.31135876051789274,-0.3605718894652007,0.5843385770049826,0.36038306789924696,-0.9990009892270116,-0.3559194109708269,0.5914857299880178,-0.36055185787973465,0.30338854840613816,-0.36056310245333584,-0.9479592307777791,0.3605530048102991,0.8153292215615305,0.3601595845406374,-0.00884255333616546,-0.36048182976479165,-0.8024611244318575,-0.36008249468733844,0.953415712738609,0.3605598837460286,-0.32006339736086886,0.3594599385443492,-0.5771399300773457,-0.3605679773892775,0.9989153738134007,0.3592420969622404 +45,-0.9765631184680146,-0.3686867926910169,0.4497802905222859,-0.36876670513486437,0.41171080650416075,0.3685735921696844,-0.9668506156980455,-0.36400848849289114,0.8029864644882515,-0.3687462182860923,-0.042061812111804486,-0.36875771841818433,-0.7502865794506939,0.3687473912832605,0.9875302409335939,0.3683450296438337,-0.4866925862734028,-0.3686745986230824,-0.3730953133607881,-0.3682661877484143,0.9555046418393013,0.3687544265584383,-0.8283949868397046,0.3676294826021753,0.08389085926040919,-0.36876270414812473,0.7218543333935213,0.3674066900750186 +46,-0.9927080176514501,-0.376879832528595,0.5781007248703209,-0.37696152080452805,0.22266942537659834,0.3767641164401218,-0.8745861537498308,-0.37209756601495536,0.9427588096111604,-0.37694057869244985,-0.38241198550762484,-0.3769523343830329,-0.43416017453398315,0.3769417777562218,0.963108266307061,0.37653047474703,-0.8453833000935627,-0.3768673674813731,0.16631531702076296,-0.3764498808094902,0.6238083095043726,0.3769489693708481,-0.9988802582838266,0.3757990266600014,0.7054664666571613,-0.3769574309069719,0.05743018757134257,0.37557128318779676 +47,-0.998934106468155,-0.3850728723661732,0.6934382632489828,-0.3851563364741917,0.024750916933274518,0.3849546407105592,-0.7279441607506484,-0.38018664353701964,0.9983173180440714,-0.38513493909880747,-0.6763929563914927,-0.3851469503478814,-0.0494894243747762,0.3851361642291832,0.7469258543626135,0.3847159198502263,-0.9970946982774902,-0.3850601363396638,0.6566710856693903,-0.38463357387056607,0.07419778758181797,0.38514351218325776,-0.7619897937807965,0.38396857071782753,0.9952501715674335,-0.3851521576658192,-0.6378122754269473,0.38373587630057493 +48,-0.995179175896744,-0.3932659122037514,0.7932026731992672,-0.39335115214385535,-0.17415433246214224,0.39314516498099666,-0.5360421255762363,-0.38827572105908387,0.964699113695868,-0.39332929950516504,-0.8883581872850486,-0.39334156631272993,0.34299461771945594,0.39333055070214445,0.38202617615916196,0.3929013649534226,-0.9046825394297735,-0.3932529051979546,0.9533410899004228,-0.39281726693164193,-0.5013321562268207,0.3933380549956675,-0.21433520069343104,0.3921381147756537,0.8169521695766363,-0.3933468844246663,-0.9907904722933045,0.39190046941335316 +49,-0.9814807439622717,-0.40145895204132953,0.8751534611529963,-0.401545967813519,-0.36611659819807135,0.4013356892514341,-0.31081158731133707,-0.39636479858114815,0.8449072106375181,-0.40152365991152267,-0.9926059243487779,-0.40153618227757853,0.6813273514434225,0.40152493717510584,-0.05893712767182949,0.4010868100566189,-0.5907725430230472,-0.40144567405624526,0.9688222283847042,-0.40100095999271784,-0.9017323544487316,0.40153259780807726,0.42073223231919765,0.4003076588334798,0.25442879700417614,-0.40154161118351356,-0.8120884444995462,0.40006506252613133 +50,-0.9579756808685278,-0.40965199187890766,0.9374501890881459,-0.40974078348318266,-0.5434829505115776,0.40952621352187146,-0.0662562899388779,-0.4044538761032124,0.6496422627988417,-0.4097180203178803,-0.9764956526027355,-0.40973079824242703,0.912093477404205,0.4097193236480672,-0.4881657078253393,0.4092722551598152,-0.13222082417153896,-0.409638442914536,0.6985483245281671,-0.4091846530537937,-0.9871314982589727,0.409727140620487,0.8842114280468097,0.4084772028913059,-0.42775641435897105,-0.4097363379423608,-0.19760167846944277,0.4082296556389095 +51,-0.9248988414362043,-0.4178450317164858,0.978693806799069,-0.4179355991528463,-0.6991823526478141,0.41771673779230895,0.18241850263485546,-0.4125429536252766,0.3963467064222014,-0.41791238072423786,-0.8419808161894278,-0.4179254142072756,0.9988600983995086,0.4179137101210285,-0.8201976664867113,0.41745770026301154,0.35870316379961587,-0.41783121177282667,0.22223692462298353,-0.4173683461148695,-0.7276972097758293,0.41792168343289676,0.9870805524007827,0.416646746949132,-0.9087611001707453,-0.417931064701208,0.5229225472857961,0.41639424875168773 +52,-0.8825807185111588,-0.426038071554064,0.9979580715479425,-0.42613041482250996,-0.8270075607814866,0.4259072620627463,0.41975139624441277,-0.4206320311473409,0.10764667918092603,-0.4261067411305955,-0.6053719543358926,-0.4261200301721241,0.9279286727991368,0.4261080965939899,-0.9889235164634342,0.42564314536620784,0.7618041070624366,-0.42602398063111746,-0.3196234686318003,-0.4255520391759454,-0.21405734993781209,0.4261162262453065,0.6873862432114228,0.42481629100695817,-0.9623612407883293,-0.42612579146005525,0.9628348927337654,0.4245588418644659 +53,-0.8314441407891479,-0.4342311113916421,0.9948103494797972,-0.4343252304921736,-0.921862587240073,0.4340977863331838,0.6309861810683414,-0.4287211086694051,-0.19066910531273035,-0.4343011015369531,-0.29535897386302384,-0.4343146461369726,0.710497713064794,0.4343024830669512,-0.9607489632091267,0.43382859046940414,0.978388836069306,-0.43421674948940814,-0.7672106142916351,-0.43373573223702133,0.3743589007021114,0.43431076905771626,0.10735355073144076,0.4329858350647843,-0.5633478525539742,-0.4343205182189024,0.8860685998674347,0.4327234349772441 +54,-0.7720000480512634,-0.44242415122922024,0.9693213316461975,-0.44252004616183727,-0.9799658614510169,0.44228831060362117,0.8029893012853434,-0.4368101861914694,-0.4719529864892929,-0.44249546194331074,0.050467633252835764,-0.44250926210182123,0.3808947868653324,0.4424968695399126,-0.7412837235565591,0.44201403557260044,0.9554298555028464,-0.44240951834769887,-0.9885082559041144,-0.4419194252980972,0.8320008169535862,0.44250531187012593,-0.5164613983033893,0.4411553791226103,0.10061683328952202,-0.44251524497774963,0.33381817042946166,0.4408880280900223 +55,-0.7048423860225389,-0.45061719106679843,0.9220634464366241,-0.4507148618315009,-0.9990009892270116,0.45047883487405865,0.9250664359638998,-0.4448992637135336,-0.7110787129773225,-0.45068982234966826,0.39017480898243306,-0.4507038780666697,-0.008843051063390898,0.4506912560128739,-0.37422459858628676,0.45019948067579674,0.6985483245281671,-0.45060228720598955,-0.9182444425414257,-0.450103118359173,0.999000911029376,0.4506998546825357,-0.9296466542620796,0.4493249231804365,0.7172598502555733,-0.4507099717365969,-0.39756652080531235,0.4490526212028005 +56,-0.6306421718625934,-0.45881023090437656,0.8540980040709283,-0.45890967750116457,-0.9782091000924481,0.45866935914449614,0.9896274201406913,-0.45298834123559784,-0.8866858958061263,-0.4588841827560259,0.6825715043442359,-0.4588984940315183,-0.3971847656702626,0.45888564248583524,0.06734481270432775,0.45838492577899304,0.2706378009844662,-0.45879505606428034,-0.5771435531187373,-0.4582868114202489,0.81702124544593,0.4588943974949454,-0.9636918813633668,0.4574944672382626,0.9965643521521099,-0.4589046984954441,-0.9156081662497405,0.45721721431557866 +57,-0.5501407895839824,-0.4670032707419547,0.7669513618598439,-0.46710449317082825,-0.9184191010545203,0.4668598834149335,0.9926581645156499,-0.4610774187576621,-0.9830880683359124,-0.4670785431623835,0.8922032825138873,-0.46709310999636683,-0.7228197390787221,0.4670800289587966,0.49550548150247814,0.46657037088218933,-0.22353429506351763,-0.466987824922571,-0.06581362102312688,-0.4664705044813248,0.3496325529793983,0.4670889403073552,-0.6047123328311921,0.46566401129608875,0.8071690674839651,-0.4670994252542913,-0.9423140861844995,0.4653818074283569 +58,-0.46414258238821615,-0.47519631057953293,0.6625806455104198,-0.4752993088404918,-0.822014630696619,0.475050407685371,0.9339702320827127,-0.46916649627972634,-0.9916739116044826,-0.47527290356874113,0.9936513314685794,-0.4752877259612153,-0.9343373690512425,0.4752744154317579,0.8250081373332234,0.47475581598538563,-0.6629773996488669,-0.47518059378086175,0.4649281015032408,-0.4746541975424007,-0.23989284923453888,0.4752834831197649,0.000888499463940174,0.4738335553539149,0.2381495580135211,-0.4752941520131385,-0.4633532894386687,0.47354640054113506 +59,-0.3735068159343654,-0.48338935041711106,0.543329796301948,-0.48349412451015555,-0.6928390310300134,0.48324093195580836,0.8172125542301782,-0.4772555738017906,-0.9116764778034536,-0.48346726397509876,0.9746146112182116,-0.48348234192606393,-0.9983436726660347,0.4834688019047193,0.9902468918556799,0.48294126108858193,-0.9401005146550224,-0.4833733626391525,0.8585388360752993,-0.4828378906034765,-0.7456167774502503,0.4834780259321747,0.606126972887711,0.48200309941174097,-0.4428754097801683,-0.4834888787719857,0.26425319770781214,0.48171099365391323 +60,-0.27913909284063765,-0.49158239025468925,0.41187693121807495,-0.4916889401798191,-0.5360421255762363,0.49143145622624584,0.6496445578602947,-0.48534465132385485,-0.7502416994418167,-0.4916616243814564,0.8374014111729101,-0.4916769578909125,-0.9047334619532951,0.49166318837768064,0.9583217514371605,0.49112670619177823,-0.9870542365219581,-0.49156613149744316,0.9989227202859935,-0.49102158366455245,-0.9908753137732103,0.4916725687445844,0.96416722649504,0.49017264346956707,-0.9156091522350825,-0.49168360553083296,0.8500555360929072,0.4898755867666914 +61,-0.18198230420236258,-0.4997754300922674,0.27117419821302213,-0.4998837558494828,-0.3578749121544809,0.4996219804966832,0.4416848093849402,-0.4934337288459191,-0.5217900644772929,-0.4998559847878139,0.5986494594932273,-0.499871573855761,-0.6682857308827129,0.499857574850642,0.7355891965505875,0.4993121512949746,-0.7923426565683493,-0.4997589003557339,0.8446733932970993,-0.49920527672562837,-0.8899925953334138,0.4998671115569942,0.9289888433216492,0.49834218752739323,-0.9577176036061816,-0.49987833228968015,0.9796991496596853,0.49804017987946964 +62,-0.08300720853584859,-0.5079684699298456,0.1243814773329888,-0.5080785715191465,-0.16544035532471615,0.5078125047671207,0.2062632387276392,-0.5015228063679833,-0.24672847707490414,-0.5080503451941715,0.2873085224446329,-0.5080661898206096,-0.3263303771762547,0.5080519613236033,0.36639656967466055,0.5074975963981708,-0.40363796037059346,-0.5079516692140247,0.4412868415479837,-0.5073889697867042,-0.4782098580959155,0.5080616543694039,0.5149387079073461,0.5065117315852192,-0.5493965012509412,-0.5080730590485273,0.5836143892093713,0.5062047729922479 +63,0.01679726771981963,-0.5161615097674237,-0.02520458337738688,-0.5162733871888101,0.03358978639461202,0.516003029037558,-0.04198278109398478,-0.5096118838900476,0.05037263036519647,-0.5162447056005292,-0.058869887187256775,-0.5162608057854581,0.06714536773211494,0.5162463477965646,-0.07574773759957326,0.5156830415013671,0.08389138589188909,-0.5161444380723154,-0.09225768593340686,-0.51557266284778,0.10062534075848271,0.5162561971818137,-0.10912011810000993,0.5146812756430454,0.11731436019975566,-0.5162677858073746,-0.1256508450615037,0.514369366105026 +64,0.11643391122887078,-0.5243545496050018,-0.17422460348296592,-0.5244682028584737,0.2312808093290893,0.5241935533079956,-0.28761851492748053,-0.5177009614121119,0.3429741007571249,-0.5244390660068867,-0.3979100537088547,-0.5244554217503066,0.45002033546831993,0.524440734269526,-0.5028102313372707,0.5238684866045634,0.5508811950736693,-0.524337206930606,-0.598590720761377,-0.5237563559088559,0.644309213076718,0.5244507399942234,-0.6886762241376951,0.5228508197008715,0.7288504449610587,-0.5244625125662218,-0.7674890386015542,0.5225339592178042 +65,0.21490718558485614,-0.53254758944258,-0.3193319146002766,-0.5326630185281375,0.41975139624441277,0.532384077578433,-0.515371522560424,-0.5257900389341761,0.6049387161914506,-0.5326334264132444,-0.6887018060562125,-0.5326500377151552,0.7618469872838072,0.5326351207424874,-0.829760294082276,0.5320539317077598,0.8829960750480726,-0.5325299757888968,-0.9283688503192838,-0.5319400489699319,0.9629173403748043,0.5326452828066331,-0.9873678508639058,0.5310203637586977,0.9975967768548264,-0.532657239325069,-0.9974755279933509,0.5306985523305824 +66,0.3112331783813484,-0.5407406292801581,-0.4572677193544254,-0.5408578341978011,0.5914878195936001,0.5405746018488704,-0.7110812250819646,-0.5338791164562403,0.8128659577678583,-0.540827786819602,-0.8959853140869721,-0.5408446536800037,0.9533947513031344,0.5408295072154486,-0.9915002735701108,0.5402393768109561,0.9989227202859935,-0.5407227446471875,-0.9843237000653864,-0.5401237420310077,0.9451507373741426,0.5408398256190428,-0.8833788744242181,0.5391899078165239,0.7971577567141399,-0.5408519660839163,-0.6921944430119904,0.5388631454433606 +67,0.4044494321394655,-0.5489336691177362,-0.5849342769425852,-0.5490526498674647,0.7396434901233433,0.5487651261193078,-0.862579341093853,-0.5419681939783045,0.948182304255887,-0.5490221472259597,-0.994626504274366,-0.5490392696448522,0.9944224473409904,0.5490238936884101,-0.9558268025539101,0.5484248219141524,0.8702782449500921,-0.5489155135054783,-0.74995133358089,-0.5483074350920835,0.5972157896512418,0.5490343684314526,-0.4191193689553377,0.5473594518743499,0.22180298765725523,-0.5490466928427634,-0.015466410103024505,0.5470277385561387 +68,0.4936245608648517,-0.5571267089553144,-0.699464471714092,-0.5572474655371285,0.8583119087818731,0.5569556503897453,-0.9604464515116701,-0.5500572715003689,0.9988003494298336,-0.5572165076323171,-0.9726646810932612,-0.5572338856097008,0.8784527043102142,0.5572182801613713,-0.7298426758512726,0.5566102670173487,0.5285593032355248,-0.557108282363769,-0.2943801043924914,-0.5564911281531594,0.04065618459701715,0.5572289112438623,0.21607059585732746,0.5555289959321761,-0.45786919226212597,-0.5572414196016107,0.6695612427843491,0.5551923316689169 +69,0.5778675561486627,-0.5653197487928925,-0.7982862024068306,-0.5654422812067921,0.9427621401971263,0.5651461746601827,-0.9985976534212823,-0.5581463490224331,0.9601985340675653,-0.5654108680386748,-0.8327628160644499,-0.5654285015745493,0.6237945946919845,0.5654126666343328,-0.35854264273329634,0.564795712120545,0.05743060993875231,-0.5653010512220596,0.2480188179788237,-0.5646748212142353,-0.5301057954227266,0.5654234540562721,0.7631399703650499,0.5636985399900022,-0.9221983366571437,-0.5654361463604578,0.995287426852285,0.5633569247816952 +70,0.6563366898290532,-0.5735127886304707,-0.8791801459964508,-0.5736370968764557,0.9896274201406913,0.5733366989306201,-0.9746608898702935,-0.5662354265444973,0.8358250433694887,-0.5736052284450324,-0.5918846502765073,-0.5736231175393979,0.27065303457104034,0.5736070531072941,0.0841453084131369,0.5729811572237413,-0.42775909963357084,-0.5734938200803504,0.7172643529108057,-0.5728585142753112,-0.9156865698618166,0.5736179968686818,0.9989761370083196,0.5718680840478282,-0.9528031935758104,-0.5736308731193052,0.7869202203314191,0.5715215178944734 +71,0.7282479242607367,-0.581705828468049,-0.9403295989069247,-0.5818319125461193,0.9970393777931617,0.5815272232010575,-0.8901244328804739,-0.5743245040665615,0.6367897908441639,-0.58179958885139,-0.2792377631814792,-0.5818177335042464,-0.12521868858818497,0.5818014395802554,0.5100794410076592,0.5811666023269376,-0.8082184629954489,-0.581686588938641,0.9749520813323892,-0.581042207336387,-0.981391680980145,0.5818125396810916,0.8273974652538418,0.5800376281056544,-0.5352898205951903,-0.5818255998781523,0.15627410496418287,0.5796861110072515 +72,0.7928827461599658,-0.5898988683056271,-0.9803612762562817,-0.590026728215783,0.9647025217927181,0.589717747471495,-0.750244349901775,-0.5824135815886258,0.38087200282269595,-0.5899939492577475,0.06726798001806399,-0.5900123494690949,-0.5013211341285863,0.5899958260532167,0.8344538008369035,0.5893520474301339,-0.9907977590117201,-0.5898793577969318,0.9450767614268256,-0.5892259003974629,-0.7042684431161633,0.5900070824935013,0.3183792970899596,0.5882071721634805,0.1339787191455119,-0.5900203266369995,-0.5582321741403956,0.5878507041200297 +73,0.8495953457517539,-0.5980919081432052,-0.9983761528744305,-0.5982215438854467,0.893906020543249,0.5979082717419324,-0.5637177069956406,-0.5905026591106901,0.09093205312158407,-0.5981883096641052,0.40561717293614663,-0.5982069654339435,-0.7982759956420706,0.5981902125261781,0.9926835730138766,0.5975374925333302,-0.9307952083420419,-0.5980721266552226,0.6364501473575123,-0.5974095934585388,-0.1811239761413819,0.598201625305911,-0.3204842648403412,0.5963767162213066,0.7402349737968508,-0.5982150533958468,-0.9731786410957026,0.5960152972328079 +74,0.8978190694876738,-0.6062849479807834,-0.9939696534702979,-0.6064163595551103,0.7874723071382904,0.6060987960123698,-0.34214182739086046,-0.5985917366327543,-0.20713058606638482,-0.6063826700704628,0.694783428218958,-0.606401581398792,-0.9692006299408723,0.6063845989991394,0.9532642929099907,0.6057229376365265,-0.6429015280524647,-0.6062648955135133,0.1401019539544995,-0.6055932865196146,0.40529230666910204,0.6063961681183208,-0.8286439589485542,0.6045462602791328,0.9983471537810038,-0.6064097801546939,-0.8658957840905578,0.6041798903455861 +75,0.937072081860799,-0.6144779878183615,-0.967240738520651,-0.6146111752247739,0.6496445578602947,0.6142893202828072,-0.0992932260959275,-0.6066808141548186,-0.4866908668879625,-0.6145770304768204,0.8997040146787455,-0.6145961973636406,-0.9871097955610528,0.6145789854721008,0.7240445676402713,0.6139083827397228,-0.19760313172099261,-0.6144576643718039,-0.39756944468818123,-0.6137769795806904,0.8501283264271917,0.6145907109307305,-0.9988557961286462,0.6127158043369588,0.7869210677373633,-0.6146045069135412,-0.29395397255195843,0.6123444834583643 +76,0.9669621797468185,-0.6226710276559396,-0.9187896818311375,-0.6228059908944377,0.48591753013236094,0.6224798445532447,0.14972894707872506,-0.6147698916768828,-0.7227765020581024,-0.622771390883178,0.9955313738379434,-0.6227908133284891,-0.8491760290380246,0.6227733719450622,0.3506633729017655,0.6220938278429191,0.2960754029059257,-0.6226504332300947,-0.8179773555910103,-0.6219606726417664,0.9979900634187405,0.6227852537431403,-0.7617018738211153,0.6208853483947849,0.2053937075559733,-0.6227992336723884,0.4357300847355405,0.6205090765711425 +77,0.9871907111680082,-0.6308640674935178,-0.8497045896807232,-0.6310008065641012,0.30281850387948656,0.6306703688236821,0.38944169952439245,-0.622858969198947,-0.894298664909385,-0.6309657512895356,0.9706460000549085,-0.6309854292933377,-0.5771760392173167,0.6309677584180234,-0.09253693157904311,0.6302792729461154,0.7172643529108057,-0.6308432020883854,-0.9971220635732705,-0.6301443657028423,0.7972251589037169,0.63097979655555,-0.21390124601824936,0.6290548924526111,-0.4727335226541953,-0.6309939604312357,0.931591678220842,0.6286736696839207 +78,0.9975555593250707,-0.639057107331096,-0.7615369643004859,-0.6391956222337649,0.10764705947598698,0.6388608930941195,0.6049408533238516,-0.6309480467210113,-0.985935791470394,-0.6391601116958933,0.8280653587341291,-0.6391800452581862,-0.21405264375427785,0.6391621448909848,-0.5173125967036866,0.6384647180493117,0.962841973854285,-0.6390359709466762,-0.8821646657745696,-0.6383280587639182,0.3179665700717952,0.6391743393679598,0.4211352409319533,0.6372244365104371,-0.9285267904950506,-0.6391886871900828,0.9275404378438044,0.6368382627966989 +79,0.9979531620813837,-0.6472501471686741,-0.6562668604783414,-0.6473904379034287,-0.09181593348888199,0.647051417364557,0.7828277148470415,-0.6390371242430756,-0.98950221014382,-0.6473544721022508,0.5850780048431273,-0.6473746612230347,0.18286495756679882,0.6473565313639461,-0.8390883258456986,0.646650163152508,0.9726822993104476,-0.6472287398049669,-0.5070119565612233,-0.646511751824994,-0.2723668896418689,0.6473688821803695,0.8844191306262055,0.6453939805682634,-0.9476194001347644,-0.64738341394893,0.42575034938048145,0.645002855909477 +80,0.9883795467216366,-0.6554431870062523,-0.5362584177972577,-0.6555852535730922,-0.28761851492748053,0.6552419416349944,0.9120421406254627,-0.6471262017651398,-0.9046793433712547,-0.6555488325086085,0.2711472665395091,-0.6555692771878833,0.5509122029238168,0.6555509178369076,-0.9937967065477539,0.6548356082557043,0.7443760744142769,-0.6554215086632575,0.017684413882941338,-0.6546954448860699,-0.7675547587590115,0.6555634249927792,0.9870082411047749,0.6535635246260895,-0.5210317989295246,-0.6555781407077773,-0.30450685472212724,0.6531674490222552 +81,0.9689303696458874,-0.6636362268438304,-0.40420676715700077,-0.6637800692427559,-0.4719546538086299,0.6634324659054318,0.9845502035041135,-0.655215279287204,-0.7390441652176856,-0.6637431929149661,-0.0756613181423876,-0.6637638931527318,0.8319825248998615,0.6637453043098689,-0.9506344036314159,0.6630210533589006,0.333820625478312,-0.6636142775215483,0.5371647495481374,-0.6628791379471458,-0.9946136679525669,0.6637579678051889,0.6870634089296191,0.6617330686839156,0.15060519865870756,-0.6637728674666244,-0.8713589015723978,0.6613320421350334 +82,0.9397999606024338,-0.6718292666814085,-0.26307750394448315,-0.6719748849124195,-0.6374754499814631,0.6716229901758692,0.9958437033202947,-0.6633043568092682,-0.5073923728444012,-0.6719375533213237,-0.4132956219015522,-0.6719585091175804,0.9817010998307771,0.6719396907828301,-0.7181952817455959,0.6712064984620969,-0.15846575497606952,-0.6718070463798391,0.8982078288685432,-0.6710628310082216,-0.8742254077153917,0.6719525106175986,0.10691185574473533,0.6699026127417418,0.7514102180466009,-0.6719675942254717,-0.9706203634271584,0.6694966352478117 +83,0.901279381009204,-0.6800223065189867,-0.11604008715482107,-0.6801697005820831,-0.777582111633649,0.6798135144463067,0.9452204649546138,-0.6713934343313326,-0.23041673094687917,-0.6801319137276812,-0.700815940964827,-0.6801531250824289,0.9764306567478603,0.6801340772557916,-0.34275931711095253,0.6793919435652932,-0.6119541919258943,-0.6799998152381297,0.9943236504843816,-0.6792465240692975,-0.44844506094033215,0.6801470534300085,-0.5168418164672298,0.6780721567995678,0.99881527077865,-0.680162320984319,-0.5490253300714311,0.6776612283605898 +84,0.8537535157642845,-0.6882153463565649,0.033603339824755434,-0.6883645162517469,-0.8866890282972477,0.6880040387167441,0.8358279961789783,-0.6794825118533968,0.06714135128722254,-0.6883262741340388,-0.9033591214400818,-0.6883477410472774,0.8170032827273709,0.6883284637287529,0.10092201395171065,0.6875773886684896,-0.9156149000436888,-0.6881925840964205,0.7971627609347764,-0.6874302171303734,0.13399004746660154,0.6883415962424182,-0.9298106487491817,0.6862417008573939,0.7764618947443824,-0.6883570477431661,0.16718891794107527,0.685825821473368 +85,0.7976972276032396,-0.696408386194143,0.18249210823637022,-0.6965593319214104,-0.9604464515116701,0.6961945629871815,0.674467790868045,-0.687571589375461,0.3587018965746475,-0.6965206345403965,-0.9963658762003635,-0.696542357012126,0.5285890546371254,0.6965228502017142,0.5245091871637875,0.6957628337716859,-0.9951011474447816,-0.6963853529547112,0.3648779530547339,-0.6956139101914494,0.669618577375581,0.6965361390548279,-0.9635725699079928,0.6944112449152201,0.18892635706005056,-0.6965517745020133,0.7936858705845179,0.6939904145861462 +86,0.7336706124276289,-0.7046014260317212,0.3272824973266379,-0.7047541475910741,-0.9959139055683656,0.704385087257619,0.4711724450525941,-0.6956606668975253,0.6182206697454172,-0.704714994946754,-0.9685587107896996,-0.7047369729769745,0.15672223743886163,0.7047172366746757,0.8436635415262688,0.7039482788748822,-0.8309519285855932,-0.704578121813002,-0.17502795585869962,-0.7037976032525252,0.9713300731598357,0.7047306818672376,-0.6043583745108809,0.7025807889730461,-0.4874641984053844,-0.7047465012608606,0.994273315837164,0.7021550076989244 +87,0.6623134030119022,-0.7127944658692994,0.46472282710614904,-0.7129489632607379,-0.9916774149981173,0.7125756115280564,0.23858187869042435,-0.7037497444195895,0.8225156317042877,-0.7129093553531117,-0.8233093712125975,-0.7129315889418231,-0.23988757504153696,0.7129116231476369,0.9948395954920617,0.7121337239780784,-0.46335669714700006,-0.7127708906712926,-0.6633092018857201,-0.711981296313601,0.933728029047689,0.7129252246796474,0.0013327489768817263,0.7107503330308722,-0.9345927245225544,-0.7129412280197078,0.6613115650481177,0.7103196008117025 +88,0.5843385770049826,-0.7209875057068775,0.5917264840718246,-0.7211437789304014,-0.9479058753109407,0.7207661357984939,-0.00884255333616546,-0.7118388219416538,0.9533377219411896,-0.7211037157594693,-0.5782300043075936,-0.7211262049066717,-0.5986244140722671,0.7211060096205982,0.9479373206068049,0.7203191690812748,0.017684413882941338,-0.7209636595295833,-0.9559467647718869,-0.7201649893746769,0.5699479208651215,0.7211197674920572,0.6064803342469434,0.7189198770886984,-0.9421676888826179,-0.721135954778555,-0.026524693830696975,0.7184841939244808 +89,0.5005252330932832,-0.7291805455444557,0.7054412398918571,-0.7293385946000651,-0.8663443196650842,0.7289566600689312,-0.2557171982245236,-0.719927899463718,0.9990009927563126,-0.729298076165827,-0.2630376043797571,-0.7293208208715201,-0.8628516206780508,0.7293003960935596,0.7122952316126568,0.7285046141844711,0.4943957636288428,-0.7291564283878741,-0.9666269156172481,-0.7283486824357528,0.00706860641972282,0.7293143103044668,0.9642855874883276,0.7270894211465245,-0.5066264673849542,-0.7293306815374021,-0.700127211500532,0.726648787037259 +90,0.41171080650416075,-0.7373735853820338,0.803313306304995,-0.7375334102697287,-0.750244349901775,0.7371471843393688,-0.4866925862734028,-0.7280169769857823,0.9554264801644313,-0.7374924365721846,0.08404930829342079,-0.7375154368363687,-0.9908535287651854,0.737494782566521,0.3348310340436826,0.7366900592876674,0.8500617877833526,-0.7373491972461648,-0.6921995337210113,-0.7365323754968286,-0.5582799756131686,0.7375088531168766,0.9288239325006898,0.7352589652043506,0.1671890979807784,-0.7375254082962495,-0.9980258810705163,0.7348133801500372 +91,0.3187827016299472,-0.7455666252196119,0.8831446876992547,-0.7457282259393924,-0.6042345054408356,0.7453377086098061,-0.6874077865650357,-0.7361060545078465,0.8265065655995335,-0.745686796978542,0.4209448578688298,-0.7457100528012172,-0.9624214515534063,0.7456891690394822,-0.10929996284788868,0.7448755043908636,0.9976030393472072,-0.7455419661044554,-0.21360723769338963,-0.7447160685579045,-0.9286053003486313,0.7457033959292864,0.5145577812485159,0.7434285092621767,0.7623730181644319,-0.7457201350550966,-0.7603616445540371,0.7429779732628153 +92,0.22266942537659834,-0.75375966505719,0.943142543354768,-0.7539230416090561,-0.43413573802021466,0.7535282328802436,-0.8453833000935627,-0.7441951320299107,0.6237572810738097,-0.7538811573848997,0.7067989178973727,-0.7539046687660658,-0.7820441888697754,0.7538835555124436,-0.5316687037109495,0.75306094949406,0.9008962742565361,-0.7537347349627462,0.32798871727500084,-0.7528997616189804,-0.9745420775300782,0.7538979387416962,-0.10956170636210355,0.7515980533200028,0.9990009954983411,-0.7539148618139438,-0.114670422207047,0.7511425663755935 +93,0.12433130982961088,-0.7619527048947682,0.9819594507812137,-0.7621178572787197,-0.24672934871926827,0.7617187571506809,-0.9507969745697455,-0.7522842095519751,0.365289616335644,-0.7620755177912573,0.9069503760168603,-0.7620992847309144,-0.47819934435573913,0.762077941985405,-0.8481791244883738,0.7612463945972564,0.5836186813718782,-0.761927503821037,0.7728440865649445,-0.7610834546800563,-0.6800432692786572,0.7620924815541058,-0.6889983800010406,0.7597675973778291,0.7657831948282388,-0.7621095885727911,0.5925555015181216,0.7593071594883717 +94,0.024750916933274518,-0.7701457447323464,0.998723665922082,-0.7703126729483833,-0.04948663888437182,0.7699092814211184,-0.9970946982774902,-0.7603732870740393,0.07419171809445486,-0.7702698781976149,0.9971299523764483,-0.7702939006957628,-0.0988573380178795,0.7702723284583664,-0.9958121661322209,0.7694318397004526,0.12345088087429217,-0.7701202726793276,0.9897483537755918,-0.7692671477411321,-0.14798578210049618,0.7702870243665155,-0.987439188728639,0.7679371414356551,0.1724055919379234,-0.7703043153316383,0.9818029515085552,0.7674717526011499 +95,-0.07507677894349374,-0.7783387845699246,0.9930587006468118,-0.778507488618047,0.14972894707872506,0.7780998056915559,-0.98139790299603,-0.7684623645961035,-0.22353350536253763,-0.7784642386039725,0.9664029608336226,-0.7784885166606114,0.2960920683172844,0.7784667149313276,-0.9451732344742293,0.7776172848036488,-0.36694200076130923,-0.7783130415376184,0.9147253979584918,-0.7774508408022079,0.43576739634305206,0.7784815671789254,-0.883170300397547,0.7761066854934812,-0.5020570547524827,-0.7784990420904855,0.8441930805744254,0.7756363457139281 +96,-0.17415433246214224,-0.7865318244075028,0.9650917778625451,-0.7867023042877107,0.34297531241880874,0.7862903299619933,-0.9046825394297735,-0.7765514421181677,-0.5012911465244279,-0.7866585990103301,0.8184951896675896,-0.7866831326254599,0.644295047539256,0.7866611014042889,-0.7063448342750308,0.7858027299068452,-0.7674946830608699,-0.7865058103959092,0.5699033116449198,-0.7856345338632839,0.8672944861372603,0.786676109991335,-0.4187159462837424,0.7842762295513074,-0.9403944237349158,-0.7866937688493326,0.2535704089646602,0.7838009388267063 +97,-0.27149179345859753,-0.7947248642450809,0.9154509743619085,-0.7948971199573743,0.5225483343739402,0.7944808542324306,-0.7717183973204453,-0.784640519640232,-0.7342699425382554,-0.7948529594166878,0.5713411327075237,-0.7948777485903085,0.890778005518003,0.7948554878772504,-0.3268790840952422,0.7939881750100416,-0.9801378996342864,-0.7946985792541998,0.05698769880193234,-0.7938182269243598,0.9958506597046849,0.7948706528037448,0.21650433833629812,0.7924457736091335,-0.9364496011666118,-0.79488849560818,-0.4731237891439766,0.7919655319394845 +98,-0.36611659819807135,-0.8029179040826591,0.8452511155723047,-0.803091935627038,0.6812890032342108,0.8026713785028682,-0.5907725430230472,-0.7927295971622963,-0.9016585914254851,-0.8030473198230453,0.25490934991793845,-0.8030723645551571,0.9966267028573839,0.8030498743502117,0.11767018608852099,0.8021736201132378,-0.9528091748729423,-0.8028913481124905,-0.47273649027613684,-0.8020019199854357,0.7765275470338924,0.8030651956161545,0.7634271384139677,0.8006153176669596,-0.4920778987410095,-0.8030832223670271,-0.9459292291967651,0.8001301250522627 +99,-0.45708328691040556,-0.8111109439202372,0.7560687389802307,-0.8112867512967016,0.8128688294673042,0.8108619027733056,-0.3730953133607881,-0.8008186746843605,-0.988504763706471,-0.811241680229403,-0.09243135758237851,-0.8112669805200055,0.945129957649275,0.811244260823173,0.5387906402886539,0.8103590652164341,-0.6921995337210113,-0.8110841169707812,-0.8630265996674409,-0.8101856130465115,0.2859410213463566,0.8112597384285644,0.9989996141917504,0.8087848617247857,0.18372572839176105,-0.8112779491258744,-0.9111279863473559,0.8082947181650408 +100,-0.5434829505115776,-0.8193039837578153,0.6499066884981489,-0.8194815669663653,0.9120421406254627,0.8190524270437429,-0.13222082417153896,-0.8089077522064247,-0.9870507494610681,-0.8194360406357606,-0.42856434016658845,-0.8194615964848541,0.7444179736513089,0.8194386472961344,0.8526347555568085,0.8185445103196304,-0.2621153054114707,-0.819276885829072,-0.9987661885361102,-0.8183693061075874,-0.30453292967229934,0.819454281240974,0.8271476768156544,0.8169544057826118,0.7731202746681896,-0.8194726758847216,-0.3873951822627929,0.816459311277819 +101,-0.6244523121228162,-0.8274970235953935,0.5291491349080861,-0.8276763826360289,0.9748552097522915,0.8272429513141804,0.11687451546354176,-0.816996829728489,-0.8974264314513984,-0.8276304010421182,-0.7127319361214913,-0.8276562124497026,0.42617875988050197,0.8276330337690957,0.9967143497239663,0.8267299554228268,0.23214389125365936,-0.8274696546873628,-0.8399187353944454,-0.8265529991686633,-0.7886247668890223,0.8276488240533837,0.31795811484910985,0.8251239498404379,0.9989042754306413,-0.8276674026435688,0.3442225009132947,0.8246239043905973 +102,-0.6991823526478141,-0.8356900634329716,0.3965080325258361,-0.8358711983056926,0.9988038779998057,0.8354334755846179,0.35870316379961587,-0.8250859072505532,-0.7276376830815321,-0.8358247614484757,-0.9104775245682311,-0.8358508284145512,0.04065529074559569,0.835827420242057,0.9423423406077478,0.8349154005260231,0.6695661670386436,-0.8356624235456533,-0.4333364483858338,-0.834736692229739,-0.9972272841544051,0.8358433668657935,-0.32090506911849054,0.833293493898264,0.7548879871482043,-0.835862129402416,0.8911227269311257,0.8327884975033755 +103,-0.7669263942241143,-0.8438831032705498,0.2549622145536561,-0.8440660139753562,0.9829333875393695,0.8436239998550553,0.5782293867611701,-0.8331749847726174,-0.492851227569807,-0.8440191218548333,-0.9978235483589577,-0.8440454443793998,-0.35128675486927924,0.8440218067150185,0.7003445103250038,0.8431008456292193,0.9430552931961209,-0.8438551924039441,0.10105883829224695,-0.8429203852908149,-0.8574696206555485,0.8440379096782032,-0.8288927676441116,0.8414630379560901,0.15583608305980468,-0.8440568561612632,0.9598266592786855,0.8409530906161536 +104,-0.8270075607814866,-0.852076143108128,0.10769049490852028,-0.8522608296450199,0.927876444743412,0.8518145241254926,0.7618041070624366,-0.8412640622946818,-0.21403983973403418,-0.852213482261191,-0.9641789025616748,-0.8522400603442482,-0.6877683459854814,0.8522161931879798,0.3189040293337652,0.8512862907324157,0.985651593376013,-0.8520479612622349,0.6056467240158131,-0.8511040783518908,-0.4181731491058256,0.852232452490613,-0.9988311369934331,0.8496325820139163,-0.5165079658975976,-0.8522515829201105,0.5134662383539936,0.8491176837289318 +105,-0.878825541163884,-0.8602691829457061,-0.041999721085329136,-0.8604556453146835,0.8358279961789783,0.8600050483959302,0.8980135377247918,-0.8493531398167461,0.08389108952076789,-0.8604078426675486,-0.8136231543801795,-0.8604346763090968,-0.9156664379249367,0.860410579660941,-0.1260320920406456,0.859471735835612,0.786926007696378,-0.8602407301205256,0.9315985295647208,-0.8592877714129667,0.16720323434360296,0.8604269953030227,-0.7614138036494987,0.8578021260717424,-0.945930247833764,-0.8604463096789576,-0.20843159698642683,0.85728227684171 +106,-0.921862587240073,-0.8684622227832842,-0.1907467138896283,-0.8686504609843472,0.7104577230146554,0.8681955726663676,0.978388836069306,-0.8574422173388102,0.37432827759741877,-0.8686022030739062,-0.5644118769694264,-0.8686292922739453,-0.9990009329949663,0.8686049661339024,-0.5458744934966553,0.8676571809388083,0.39553349032868806,-0.8684334989788163,0.9827744583211934,-0.8674714644740427,0.6941707175695545,0.8686215381154325,-0.21346724916051735,0.8659716701295685,-0.9304667536458673,-0.8686410364378048,-0.8184803972271192,0.8654468699544882 +107,-0.9556886870718162,-0.8766552626208624,-0.3352099467254821,-0.8768452766540109,0.5567637424127302,0.876386096936805,0.9979326553363491,-0.8655312948608744,0.6313278354799385,-0.8767965634802638,-0.2467630776839223,-0.8768239082387939,-0.9246151467833709,0.8767993526068638,-0.8570301197939464,0.8758426260420045,-0.09269942018419589,-0.8766262678371071,0.7440801211804074,-0.8756551575351185,0.9786443977315177,0.8768160809278422,0.4215381664944424,0.8741412141873947,-0.47739020627423745,-0.8768357631966521,-0.989314395098586,0.8736114630672663 +108,-0.9799658614510169,-0.8848483024584405,-0.47214508682764667,-0.8850400923236745,0.3808733483703954,0.8845766212072423,0.9554298555028464,-0.8736203723829388,0.8319327580719243,-0.8849909238866215,0.10080687354039042,-0.8850185242036425,-0.7042529593378598,0.8849937390798251,-0.9975460824982016,0.8840280711452009,-0.5582362796306857,-0.8848190366953977,0.2859186410454222,-0.8838388505961944,0.9212494339897546,0.8850106237402519,0.8846266587931284,0.8823107582452206,0.2002104145358934,-0.8850304899554993,-0.6292602641934263,0.8817760561800446 +109,-0.9944515408762445,-0.8930413422960187,-0.5984768661640778,-0.8932349079933382,0.18979873584388987,0.8927671454776799,0.8535230548034795,-0.8817094499050031,0.9582236050900936,-0.893185284292979,0.4361535302265275,-0.893213140168491,-0.3727047147310378,0.8931881255527864,-0.9394448391035939,0.8922135162483973,-0.8870974286526992,-0.8930118055536885,-0.2565748155701097,-0.8920225436572703,0.5420355384427306,0.8932051665526617,0.9869357351651403,0.8904803023030468,0.7836489490157829,-0.8932252167143465,0.06846893322872644,0.8899406492928228 +110,-0.9990009892270116,-0.9012343821335969,-0.7113681453260299,-0.9014297236630018,-0.00884255333616546,0.9009576697481173,0.6985483245281671,-0.8897985274270672,0.9989191912961758,-0.9013796446993365,0.7186145762733179,-0.9014077561333395,0.01768540929839667,0.9013825120257478,-0.6942946838838123,0.9003989613515935,-0.9987661885361102,-0.9012045744119791,-0.7233912850782432,-0.900206236718346,-0.026526965142686314,0.9013997093650714,0.686740439155011,0.898649846360873,0.9985251379209477,-0.9014199434731938,0.7294561768076743,0.898105242405601 +111,-0.9935687499192448,-0.9094274219711751,-0.8082836295626861,-0.9096245393326654,-0.20713131781899857,0.9091481940185547,0.5001412427975199,-0.8978876049491313,0.9503843011760645,-0.9095740051056942,0.9139403177845447,-0.909602372098188,0.405283396066496,0.9095768984987092,-0.3109064334604978,0.9085844064547897,-0.865902152277306,-0.9093973432702699,-0.9768428035765745,-0.9083899297794219,-0.585822836618184,0.9095942521774812,0.10647013967439092,0.9068193904186991,0.7437793520762195,-0.909614670232041,0.9990009965741318,0.9062698355183791 +112,-0.9782091000924481,-0.9176204618087531,-0.8870468060354797,-0.9178193550023291,-0.39716241030105964,0.9173387182889923,0.2706378009844662,-0.9059766824711957,0.8169544119150898,-0.9177683655120518,0.9984466151224077,-0.9177969880630366,0.7288960459693473,0.9177712849716705,0.1343850896591849,0.9167698515579861,-0.5210350697477014,-0.9175901121285607,-0.9421736034145309,-0.9165736228404978,-0.9404739370341241,0.9177887949898909,-0.5172221327069243,0.9149889344765252,0.1392225150770804,-0.9178093969908881,0.7324596415666184,0.9144344286311573 +113,-0.9530755082904647,-0.9258135016463314,-0.9458888236041615,-0.9260141706719928,-0.5713598908028816,0.9255292425594296,0.02430741151908543,-0.91406575999326,0.6105484181332109,-0.9259627259184094,0.9618866931770942,-0.9259916040278852,0.9374320371841065,0.9259656714446319,0.5529197626265396,0.9249552966611825,-0.04860043041052775,-0.9257828809868514,-0.6296093983195374,-0.9247573159015737,-0.9665904336389802,0.9259833378023007,-0.9299744598723458,0.9231584785343514,-0.5308128461746043,-0.9260041237497354,0.07286413669307466,0.9225990217439355 +114,-0.9184191010545203,-0.9340065414839094,-0.9834882174100049,-0.9342089863416565,-0.7227790554887887,0.933719766829867,-0.22353429506351763,-0.9221548375153242,0.34960395252605325,-0.934157086324767,0.8086936097207147,-0.9341862199927337,0.9979681219885344,0.9341600579175932,0.8613649065220053,0.9331407417643787,0.43573328929041,-0.933975649845142,-0.13134129935853708,-0.9329410089626496,-0.6550490827923521,0.9341778806147104,-0.9634530684306402,0.9313280225921775,-0.9511986316908706,-0.9341988505085826,-0.6258318860497718,0.9307636148567138 +115,-0.8745861537498308,-0.9421995813214876,-0.9990005861276121,-0.94240380201132,-0.8453833000935627,0.9419102911003044,-0.4574777218498348,-0.9302439150373883,0.057430407048144685,-0.9423514467311247,0.5574427268742794,-0.9423808359575822,0.9009469836597984,0.9423544443905545,0.9983073056655124,0.941326186867575,0.8133843030433217,-0.9421684187034328,0.4056660413949089,-0.9411247020237254,-0.11468024144592093,0.94237242342712,-0.6040042970076678,0.9394975666500036,-0.9242208378343082,-0.9423935772674298,-0.9886925863109296,0.938928207969492 +116,-0.822014630696619,-0.9503926211590659,-0.9920775553982699,-0.9505986176809836,-0.9342847802848783,0.950100815370742,-0.6629773996488669,-0.9383329925594527,-0.2398732256491951,-0.9505458071374823,0.23859936348114957,-0.9505754519224306,0.6616861266386566,0.9505488308635158,0.9364809347660296,0.9495116319707713,0.9918904716419369,-0.9503611875617235,0.8230217954704712,-0.9493083950848014,0.4657497076088322,0.9505669662395299,0.0017769982270118311,0.9476671107078298,-0.4625675425953155,-0.950588304026277,-0.8209988342337814,0.9470928010822701 +117,-0.7612298091753971,-0.9585856609966439,-0.9628746015669893,-0.9587934333506474,-0.9859392745923515,0.9582913396411793,-0.8272563538165842,-0.946422070081517,-0.5157496975020239,-0.9587401675438398,-0.10917526416038911,-0.9587700678872793,0.31795957937964214,0.9587432173364773,0.6881957825716858,0.9576970770739677,0.9275472593930406,-0.9585539564200142,0.9976264842611329,-0.9574920881458773,0.8834798840926075,0.9587615090519396,0.6068335760048055,0.9558366547656557,0.21663849574348615,-0.9587830307851243,-0.21273883042355266,0.9552573941950483 +118,-0.6928390310300134,-0.9667787008342221,-0.9120475600189895,-0.9669882490203111,-0.9982874813331263,0.9664818639116167,-0.9401005146550224,-0.9545111476035812,-0.7455557849091524,-0.9669345279501975,-0.44371189162149094,-0.9669646838521279,-0.07596579416635837,0.9669376038094386,0.3028868617699809,0.9658825221771639,0.6361081287031403,-0.966746725278305,0.8779802879067796,-0.965675781206953,0.9925851191869678,0.9669560518643494,0.9644037583190236,0.9640061988234819,0.7939560644642447,-0.9669777575439714,0.5096815658374838,0.9634219873078265 +119,-0.6175256343087134,-0.9749717406718003,-0.840737896530836,-0.9751830646899747,-0.9708371164714672,0.974672388182054,-0.9944937787950584,-0.9626002251256454,-0.9087635945027708,-0.9751288883565551,-0.7244464225498618,-0.9751592998169764,-0.45789783914980475,0.9751319902823998,-0.14272858852872933,0.9740679672803602,0.18892754306014506,-0.9749394941365956,0.4993729663896582,-0.9738594742680289,0.7549518152961344,0.975150594676759,0.928658838510381,0.972175742881308,0.9978636901617591,-0.9751724843028188,0.9585954872104855,0.9715865804206046 +120,-0.5360421255762363,-0.9831647805093785,-0.7505470724099633,-0.9833778803596382,-0.9046825394297735,0.9828629124524917,-0.9870542365219581,-0.9706893026477097,-0.9907942587257301,-0.9833232487629128,-0.9173385109049929,-0.983353915781825,-0.7675378835918268,0.9833263767553613,-0.559925949697151,0.9822534123835565,-0.3045090942023887,-0.9831322629948863,-0.026524888905216166,-0.9820431673291049,0.25359212222226457,0.9833451374891689,0.5141767531159706,0.9803452869391341,0.7324604303259336,-0.9833672110616659,0.8931057296516948,0.9797511735333828 +121,-0.44920266111697,-0.9913578203469565,-0.6435005791262448,-0.991572696029302,-0.8024611244318575,0.991053436722929,-0.9182444425414257,-0.978778380169774,-0.9843202226509236,-0.9915176091692702,-0.9989991086265334,-0.9915485317466733,-0.9560005728421136,0.9915207632283226,-0.86563880934501,0.9904388574867529,-0.7233912850782432,-0.9913250318531771,-0.5445992028628601,-0.9902268603901808,-0.3363545950350067,0.9915396803015786,-0.11000327301802088,0.9885148309969604,0.12256958509788746,-0.9915619378205132,0.3483555550166687,0.9879157666461612 +122,-0.3578749121544809,-0.9995508601845348,-0.5220024501402831,-0.9997675116989656,-0.6682481167153574,0.9992439609933664,-0.7923426565683493,-0.9868674576918381,-0.8899197926396045,-0.9997119695756278,-0.9595264947002462,-0.999743147711522,-0.9935317921867435,0.999715149701284,-0.9989979654203176,0.9986243025899492,-0.9651620602138787,-0.9995178007114678,-0.9020434613640808,-0.9984105534512567,-0.8088029752641013,0.9997342231139884,-0.6893203999899893,0.9966843750547865,-0.544967651204293,-0.9997566645793603,-0.383329965619578,0.9960803597589393 diff --git a/python/cuml/tsa/arima.pxd b/python/cuml/tsa/arima.pxd index 12095ed20e..14c4286e55 100644 --- a/python/cuml/tsa/arima.pxd +++ b/python/cuml/tsa/arima.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,11 +16,12 @@ cdef extern from "cuml/tsa/arima_common.h" namespace "ML": ctypedef struct ARIMAOrder: - int p # Basic order + int p # Basic order int d int q - int P # Seasonal order + int P # Seasonal order int D int Q - int s # Seasonal period - int k # Fit intercept? + int s # Seasonal period + int k # Fit intercept? + int n_exog # Number of exogenous regressors diff --git a/python/cuml/tsa/arima.pyx b/python/cuml/tsa/arima.pyx index a92be021b3..80e7845415 100644 --- a/python/cuml/tsa/arima.pyx +++ b/python/cuml/tsa/arima.pyx @@ -42,6 +42,7 @@ from cuml.internals import _deprecate_pos_args cdef extern from "cuml/tsa/arima_common.h" namespace "ML": cdef cppclass ARIMAParams[DataT]: DataT* mu + DataT* beta DataT* ar DataT* ma DataT* sar @@ -76,37 +77,40 @@ cdef extern from "cuml/tsa/batched_arima.hpp" namespace "ML": void batched_loglike( handle_t& handle, const ARIMAMemory[double]& arima_mem, - const double* y, int batch_size, int nobs, const ARIMAOrder& order, - const double* params, double* loglike, bool trans, bool host_loglike, - LoglikeMethod method, int truncate) + const double* y, const double* d_exog, int batch_size, int nobs, + const ARIMAOrder& order, const double* params, double* loglike, + bool trans, bool host_loglike, LoglikeMethod method, int truncate) void batched_loglike( handle_t& handle, const ARIMAMemory[double]& arima_mem, - const double* y, int batch_size, int n_obs, const ARIMAOrder& order, - const ARIMAParams[double]& params, double* loglike, bool trans, - bool host_loglike, LoglikeMethod method, int truncate) + const double* y, const double* d_exog, int batch_size, int n_obs, + const ARIMAOrder& order, const ARIMAParams[double]& params, + double* loglike, bool trans, bool host_loglike, LoglikeMethod method, + int truncate) void batched_loglike_grad( handle_t& handle, const ARIMAMemory[double]& arima_mem, - const double* d_y, int batch_size, int nobs, const ARIMAOrder& order, - const double* d_x, double* d_grad, double h, bool trans, - LoglikeMethod method, int truncate) + const double* d_y, const double* d_exog, int batch_size, int nobs, + const ARIMAOrder& order, const double* d_x, double* d_grad, double h, + bool trans, LoglikeMethod method, int truncate) void cpp_predict "predict" ( handle_t& handle, const ARIMAMemory[double]& arima_mem, - const double* d_y, int batch_size, int nobs, int start, int end, - const ARIMAOrder& order, const ARIMAParams[double]& params, - double* d_y_p, bool pre_diff, double level, double* d_lower, - double* d_upper) + const double* d_y, const double* d_exog, const double* d_exog_fut, + int batch_size, int nobs, int start, int end, const ARIMAOrder& order, + const ARIMAParams[double]& params, double* d_y_p, bool pre_diff, + double level, double* d_lower, double* d_upper) void information_criterion( handle_t& handle, const ARIMAMemory[double]& arima_mem, - const double* d_y, int batch_size, int nobs, const ARIMAOrder& order, - const ARIMAParams[double]& params, double* ic, int ic_type) + const double* d_y, const double* d_exog, int batch_size, int nobs, + const ARIMAOrder& order, const ARIMAParams[double]& params, + double* ic, int ic_type) void estimate_x0( handle_t& handle, ARIMAParams[double]& params, const double* d_y, - int batch_size, int nobs, const ARIMAOrder& order, bool missing) + const double* d_exog, int batch_size, int nobs, + const ARIMAOrder& order, bool missing) cdef extern from "cuml/tsa/batched_kalman.hpp" namespace "ML": @@ -126,6 +130,8 @@ cdef class ARIMAParamsWrapper: cdef uintptr_t d_mu_ptr = \ model.mu_.ptr if order.k else NULL + cdef uintptr_t d_beta_ptr = \ + model.beta_.ptr if order.n_exog else NULL cdef uintptr_t d_ar_ptr = \ model.ar_.ptr if order.p else NULL cdef uintptr_t d_ma_ptr = \ @@ -137,6 +143,7 @@ cdef class ARIMAParamsWrapper: cdef uintptr_t d_sigma2_ptr = model.sigma2_.ptr self.params.mu = d_mu_ptr + self.params.beta = d_beta_ptr self.params.ar = d_ar_ptr self.params.ma = d_ma_ptr self.params.sar = d_sar_ptr @@ -161,7 +168,7 @@ class ARIMA(Base): Parameters ---------- endog : dataframe or array-like (device or host) - The time series data, assumed to have each time series in columns. + Endogenous variable, assumed to have each time series in columns. Acceptable formats: cuDF DataFrame, cuDF Series, NumPy ndarray, Numba device ndarray, cuda array interface compliant array like CuPy. Missing values are accepted, represented by NaN. @@ -169,6 +176,13 @@ class ARIMA(Base): The ARIMA order (p, d, q) of the model seasonal_order: Tuple[int, int, int, int] The seasonal ARIMA order (P, D, Q, s) of the model + exog : dataframe or array-like (device or host) + Exogenous variables, assumed to have each time series in columns, + such that variables associated with a same batch member are adjacent + (number of columns: n_exog * batch_size) + Acceptable formats: cuDF DataFrame, cuDF Series, NumPy ndarray, + Numba device ndarray, cuda array interface compliant array like CuPy. + Missing values are not supported. fit_intercept : bool or int (default = True) Whether to include a constant trend mu in the model simple_differencing: bool or int (default = True) @@ -199,7 +213,7 @@ class ARIMA(Base): Attributes ---------- order : ARIMAOrder - The ARIMA order of the model (p, d, q, P, D, Q, s, k) + The ARIMA order of the model (p, d, q, P, D, Q, s, k, n_exog) d_y: device array Time series data on device n_obs: int @@ -284,6 +298,7 @@ class ARIMA(Base): _temp_mem = CumlArrayDescriptor() mu_ = CumlArrayDescriptor() + beta_ = CumlArrayDescriptor() ar_ = CumlArrayDescriptor() ma_ = CumlArrayDescriptor() sar_ = CumlArrayDescriptor() @@ -296,6 +311,7 @@ class ARIMA(Base): *, order: Tuple[int, int, int] = (1, 1, 1), seasonal_order: Tuple[int, int, int, int] = (0, 0, 0, 0), + exog=None, fit_intercept=True, simple_differencing=True, handle=None, @@ -313,13 +329,6 @@ class ARIMA(Base): output_type=output_type) self._set_base_attributes(output_type=endog) - # Set the ARIMA order - cdef ARIMAOrder cpp_order - cpp_order.p, cpp_order.d, cpp_order.q = order - cpp_order.P, cpp_order.D, cpp_order.Q, cpp_order.s = seasonal_order - cpp_order.k = int(fit_intercept) - self.order = cpp_order - # Check validity of the ARIMA order and seasonal order p, d, q = order P, D, Q, s = seasonal_order @@ -330,7 +339,7 @@ class ARIMA(Base): raise ValueError("ERROR: Invalid order. Required: d+D <= 2") if s != 0 and (p >= s or q >= s): raise ValueError("ERROR: Invalid order. Required: s > p, s > q") - if p + q + P + Q + cpp_order.k == 0: + if p + q + P + Q + int(fit_intercept) == 0: raise ValueError("ERROR: Invalid order. At least one parameter" " among p, q, P, Q and fit_intercept must be" " non-zero") @@ -340,7 +349,7 @@ class ARIMA(Base): raise ValueError("ERROR: Invalid order. " "Required: max(p+s*P, q+s*Q) <= 1024") - # Get device array. Float64 only for now. + # Endogenous variable. Float64 only for now. self.d_y, self.n_obs, self.batch_size, self.dtype \ = input_to_cuml_array(endog, check_dtype=np.float64) @@ -348,10 +357,38 @@ class ARIMA(Base): raise ValueError("ERROR: Number of observations too small for the" " given order") + # Exogenous variables + if exog is not None: + self.d_exog, n_obs_exog, n_cols_exog, _ \ + = input_to_cuml_array(exog, check_dtype=np.float64) + + if n_cols_exog % self.batch_size != 0: + raise ValueError("Number of columns in exog is not a multiple" + " of batch_size") + if n_obs_exog != self.n_obs: + raise ValueError("Number of observations mismatch between" + " endog and exog") + + n_exog = n_cols_exog // self.batch_size + else: + n_exog = 0 + + # Set the ARIMA order + cdef ARIMAOrder cpp_order + cpp_order.p, cpp_order.d, cpp_order.q = order + cpp_order.P, cpp_order.D, cpp_order.Q, cpp_order.s = seasonal_order + cpp_order.k = int(fit_intercept) + cpp_order.n_exog = n_exog + self.order = cpp_order + self.simple_differencing = simple_differencing self._d_y_diff = CumlArray.empty( (self.n_obs - d - s * D, self.batch_size), self.dtype) + if n_exog > 0: + self._d_exog_diff = CumlArray.empty( + (self.n_obs - d - s * D, self.batch_size * n_exog), + self.dtype) self.n_obs_diff = self.n_obs - d - D * s @@ -371,6 +408,8 @@ class ARIMA(Base): cdef uintptr_t d_y_ptr = self.d_y.ptr cdef uintptr_t d_y_diff_ptr = self._d_y_diff.ptr + cdef uintptr_t d_exog_ptr + cdef uintptr_t d_exog_diff_ptr cdef handle_t* handle_ = self.handle.getHandle() cdef ARIMAOrder cpp_order_diff = self.order @@ -391,6 +430,14 @@ class ARIMA(Base): cpp_order_diff.d = 0 cpp_order_diff.D = 0 self.order_diff = cpp_order_diff + + if cpp_order_diff.n_exog > 0: + d_exog_ptr = self.d_exog.ptr + d_exog_diff_ptr = self._d_exog_diff.ptr + batched_diff(handle_[0], d_exog_diff_ptr, + d_exog_ptr, + self.batch_size * cpp_order_diff.n_exog, + self.n_obs, self.order) else: self.order_diff = None @@ -422,6 +469,11 @@ class ARIMA(Base): cdef uintptr_t d_y_kf_ptr = \ self._d_y_diff.ptr if self.simple_differencing else self.d_y.ptr + cdef uintptr_t d_exog_kf_ptr = NULL + if order.n_exog: + d_exog_kf_ptr = (self._d_exog_diff.ptr if self.simple_differencing + else self.d_exog.ptr) + n_obs_kf = (self.n_obs_diff if self.simple_differencing else self.n_obs) @@ -438,9 +490,10 @@ class ARIMA(Base): d_temp_mem) information_criterion(handle_[0], arima_mem_ptr[0], - d_y_kf_ptr, self.batch_size, - n_obs_kf, order_kf, cpp_params, - d_ic_ptr, ic_type_id) + d_y_kf_ptr, d_exog_kf_ptr, + self.batch_size, n_obs_kf, order_kf, + cpp_params, d_ic_ptr, + ic_type_id) del arima_mem_ptr @@ -465,7 +518,8 @@ class ARIMA(Base): def complexity(self): """Model complexity (number of parameters)""" cdef ARIMAOrder order = self.order - return order.p + order.P + order.q + order.Q + order.k + 1 + return (order.p + order.P + order.q + order.Q + order.k + order.n_exog + + 1) @cuml.internals.api_base_return_autoarray(input_arg=None) def get_fit_params(self) -> Dict[str, CumlArray]: @@ -483,8 +537,9 @@ class ARIMA(Base): """ cdef ARIMAOrder order = self.order params = dict() - names = ["mu", "ar", "ma", "sar", "sma", "sigma2"] - criteria = [order.k, order.p, order.q, order.P, order.Q, True] + names = ["mu", "beta", "ar", "ma", "sar", "sma", "sigma2"] + criteria = [order.k, order.n_exog, order.p, order.q, order.P, order.Q, + True] for i in range(len(names)): if criteria[i] > 0: params[names[i]] = getattr(self, "{}_".format(names[i])) @@ -504,7 +559,7 @@ class ARIMA(Base): (n, batch_size) for any other type, where n is the corresponding number of parameters of this type. """ - for param_name in ["mu", "ar", "ma", "sar", "sma", "sigma2"]: + for param_name in ["mu", "beta", "ar", "ma", "sar", "sma", "sigma2"]: if param_name in params: array, *_ = input_to_cuml_array(params[param_name], check_dtype=np.float64) @@ -545,7 +600,8 @@ class ARIMA(Base): self, start=0, end=None, - level=None + level=None, + exog=None, ) -> Union[CumlArray, Tuple[CumlArray, CumlArray, CumlArray]]: """Compute in-sample and/or out-of-sample prediction for each series @@ -559,6 +615,11 @@ class ARIMA(Base): level: float or None (default = None) Confidence level for prediction intervals, or None to return only the point forecasts. ``0 < level < 1`` + exog : dataframe or array-like (device or host) + Future values for exogenous variables. Assumed to have each time + series in columns, such that variables associated with a same + batch member are adjacent. + Shape = (end - n_obs, n_exog * batch_size) Returns ------- @@ -607,9 +668,34 @@ class ARIMA(Base): if end is None: end = self.n_obs + if order.n_exog > 0 and end > self.n_obs and exog is None: + raise ValueError("The model was fit with a regression component," + " so future values must be provided via `exog`") + elif order.n_exog == 0 and exog is not None: + raise ValueError("A value was given for `exog` but the model was" + " fit without any regression component") + elif end <= self.n_obs and exog is not None: + raise ValueError("A value was given for `exog` but only in-sample" + " predictions were requested") + cdef handle_t* handle_ = self.handle.getHandle() predict_size = end - start + # Future values of the exogenous variables + cdef uintptr_t d_exog_fut_ptr = NULL + if order.n_exog and end > self.n_obs: + d_exog_fut, n_obs_fut, n_cols_fut, _ \ + = input_to_cuml_array(exog, check_dtype=np.float64) + if n_obs_fut != end - self.n_obs: + raise ValueError( + "Dimensions mismatch: `exog` should contain {}" + " observations per column".format(end - self.n_obs)) + elif n_cols_fut != self.batch_size * order.n_exog: + raise ValueError( + "Dimensions mismatch: `exog` should have {} columns" + .format(self.batch_size * order.n_exog)) + d_exog_fut_ptr = d_exog_fut.ptr + # allocate predictions and intervals device memory cdef uintptr_t d_y_p_ptr = NULL cdef uintptr_t d_lower_ptr = NULL @@ -626,6 +712,9 @@ class ARIMA(Base): d_upper_ptr = d_upper.ptr cdef uintptr_t d_y_ptr = self.d_y.ptr + cdef uintptr_t d_exog_ptr = NULL + if order.n_exog: + d_exog_ptr = self.d_exog.ptr cdef uintptr_t d_temp_mem = self._temp_mem.ptr arima_mem_ptr = new ARIMAMemory[double]( @@ -633,6 +722,7 @@ class ARIMA(Base): d_temp_mem) cpp_predict(handle_[0], arima_mem_ptr[0], d_y_ptr, + d_exog_ptr, d_exog_fut_ptr, self.batch_size, self.n_obs, start, end, order, cpp_params, d_y_p_ptr, self.simple_differencing, @@ -653,7 +743,8 @@ class ARIMA(Base): def forecast( self, nsteps: int, - level=None + level=None, + exog=None ) -> Union[CumlArray, Tuple[CumlArray, CumlArray, CumlArray]]: """Forecast the given model `nsteps` into the future. @@ -664,6 +755,11 @@ class ARIMA(Base): level: float or None (default = None) Confidence level for prediction intervals, or None to return only the point forecasts. 0 < level < 1 + exog : dataframe or array-like (device or host) + Future values for exogenous variables. Assumed to have each time + series in columns, such that variables associated with a same + batch member are adjacent. + Shape = (nsteps, n_exog * batch_size) Returns ------- @@ -687,7 +783,7 @@ class ARIMA(Base): y_fc = model.forecast(10) """ - return self.predict(self.n_obs, self.n_obs + nsteps, level) + return self.predict(self.n_obs, self.n_obs + nsteps, level, exog) @cuml.internals.api_base_return_any_skipall def _create_arrays(self): @@ -696,6 +792,9 @@ class ARIMA(Base): if order.k and not hasattr(self, "mu_"): self.mu_ = CumlArray.empty(self.batch_size, np.float64) + if order.n_exog and not hasattr(self, "beta_"): + self.beta_ = CumlArray.empty((order.n_exog, self.batch_size), + np.float64) if order.p and not hasattr(self, "ar_"): self.ar_ = CumlArray.empty((order.p, self.batch_size), np.float64) @@ -723,12 +822,15 @@ class ARIMA(Base): cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params cdef uintptr_t d_y_ptr = self.d_y.ptr + cdef uintptr_t d_exog_ptr = NULL + if order.n_exog: + d_exog_ptr = self.d_exog.ptr cdef handle_t* handle_ = self.handle.getHandle() # Call C++ function estimate_x0(handle_[0], cpp_params, d_y_ptr, - self.batch_size, self.n_obs, order, - self.missing) + d_exog_ptr, self.batch_size, + self.n_obs, order, self.missing) @cuml.internals.api_base_return_any_skipall def fit(self, @@ -869,6 +971,10 @@ class ARIMA(Base): cdef uintptr_t d_y_kf_ptr = \ self._d_y_diff.ptr if diff else self.d_y.ptr + cdef uintptr_t d_exog_kf_ptr = NULL + if order.n_exog: + d_exog_kf_ptr = self._d_exog_diff.ptr if diff else self.d_exog.ptr + cdef handle_t* handle_ = self.handle.getHandle() n_obs_kf = (self.n_obs_diff if diff else self.n_obs) @@ -879,9 +985,10 @@ class ARIMA(Base): d_temp_mem) batched_loglike(handle_[0], arima_mem_ptr[0], d_y_kf_ptr, - self.batch_size, n_obs_kf, order_kf, - d_x_ptr, vec_loglike.data(), - trans, True, ll_method, truncate) + d_exog_kf_ptr, self.batch_size, + n_obs_kf, order_kf, d_x_ptr, + vec_loglike.data(), trans, + True, ll_method, truncate) del arima_mem_ptr @@ -936,6 +1043,10 @@ class ARIMA(Base): cdef uintptr_t d_y_kf_ptr = \ self._d_y_diff.ptr if diff else self.d_y.ptr + cdef uintptr_t d_exog_kf_ptr = NULL + if order.n_exog: + d_exog_kf_ptr = self._d_exog_diff.ptr if diff else self.d_exog.ptr + cdef handle_t* handle_ = self.handle.getHandle() cdef uintptr_t d_temp_mem = self._temp_mem.ptr @@ -944,7 +1055,8 @@ class ARIMA(Base): d_temp_mem) batched_loglike_grad(handle_[0], arima_mem_ptr[0], - d_y_kf_ptr, self.batch_size, + d_y_kf_ptr, d_exog_kf_ptr, + self.batch_size, (self.n_obs_diff if diff else self.n_obs), order_kf, d_x_ptr, d_grad, h, trans, ll_method, @@ -975,6 +1087,11 @@ class ARIMA(Base): cdef uintptr_t d_y_kf_ptr = \ self._d_y_diff.ptr if self.simple_differencing else self.d_y.ptr + cdef uintptr_t d_exog_kf_ptr = NULL + if order.n_exog: + d_exog_kf_ptr = (self._d_exog_diff.ptr if self.simple_differencing + else self.d_exog.ptr) + n_obs_kf = (self.n_obs_diff if self.simple_differencing else self.n_obs) @@ -987,9 +1104,10 @@ class ARIMA(Base): d_temp_mem) batched_loglike(handle_[0], arima_mem_ptr[0], d_y_kf_ptr, - self.batch_size, n_obs_kf, order_kf, - cpp_params, vec_loglike.data(), - False, True, ll_method, 0) + d_exog_kf_ptr, self.batch_size, + n_obs_kf, order_kf, cpp_params, + vec_loglike.data(), False, + True, ll_method, 0) del arima_mem_ptr