Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add level-3 BLAS triangular Sylvester equation solver #651

Merged
merged 8 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions LAPACKE/include/lapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -22002,6 +22002,84 @@ void LAPACK_ztrsyl_base(
#define LAPACK_ztrsyl(...) LAPACK_ztrsyl_base(__VA_ARGS__)
#endif

/* Prototype of the single-precision complex trsyl3 Fortran entry point.
 * LAPACK_GLOBAL (defined elsewhere) receives both the lower-case and the
 * upper-case spelling of the routine name. All arguments are passed by
 * pointer. This variant takes only the real workspace swork/ldswork. */
#define LAPACK_ctrsyl3_base LAPACK_GLOBAL(ctrsyl3,CTRSYL3)
void LAPACK_ctrsyl3_base(
char const* trana, char const* tranb,
lapack_int const* isgn, lapack_int const* m, lapack_int const* n,
lapack_complex_float const* A, lapack_int const* lda,
lapack_complex_float const* B, lapack_int const* ldb,
lapack_complex_float* C, lapack_int const* ldc, float* scale,
float* swork, lapack_int const *ldswork,
lapack_int* info
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t, size_t
#endif
);
/* Call through LAPACK_ctrsyl3(...). When LAPACK_FORTRAN_STRLEN_END is defined
 * the macro appends "1, 1" to fill the two trailing size_t parameters
 * (presumably the hidden lengths of trana and tranb -- confirm). */
#ifdef LAPACK_FORTRAN_STRLEN_END
#define LAPACK_ctrsyl3(...) LAPACK_ctrsyl3_base(__VA_ARGS__, 1, 1)
#else
#define LAPACK_ctrsyl3(...) LAPACK_ctrsyl3_base(__VA_ARGS__)
#endif

/* Prototype of the double-precision real trsyl3 Fortran entry point.
 * LAPACK_GLOBAL (defined elsewhere) receives both the lower-case and the
 * upper-case spelling of the routine name. All arguments are passed by
 * pointer. Unlike the complex variants, this one takes an integer workspace
 * (iwork/liwork) in addition to the real workspace swork/ldswork. */
#define LAPACK_dtrsyl3_base LAPACK_GLOBAL(dtrsyl3,DTRSYL3)
void LAPACK_dtrsyl3_base(
char const* trana, char const* tranb,
lapack_int const* isgn, lapack_int const* m, lapack_int const* n,
double const* A, lapack_int const* lda,
double const* B, lapack_int const* ldb,
double* C, lapack_int const* ldc, double* scale,
lapack_int* iwork, lapack_int const* liwork,
double* swork, lapack_int const *ldswork,
lapack_int* info
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t, size_t
#endif
);
/* Call through LAPACK_dtrsyl3(...). When LAPACK_FORTRAN_STRLEN_END is defined
 * the macro appends "1, 1" to fill the two trailing size_t parameters
 * (presumably the hidden lengths of trana and tranb -- confirm). */
#ifdef LAPACK_FORTRAN_STRLEN_END
#define LAPACK_dtrsyl3(...) LAPACK_dtrsyl3_base(__VA_ARGS__, 1, 1)
#else
#define LAPACK_dtrsyl3(...) LAPACK_dtrsyl3_base(__VA_ARGS__)
#endif

/* Prototype of the single-precision real trsyl3 Fortran entry point.
 * LAPACK_GLOBAL (defined elsewhere) receives both the lower-case and the
 * upper-case spelling of the routine name. All arguments are passed by
 * pointer. Unlike the complex variants, this one takes an integer workspace
 * (iwork/liwork) in addition to the real workspace swork/ldswork. */
#define LAPACK_strsyl3_base LAPACK_GLOBAL(strsyl3,STRSYL3)
void LAPACK_strsyl3_base(
char const* trana, char const* tranb,
lapack_int const* isgn, lapack_int const* m, lapack_int const* n,
float const* A, lapack_int const* lda,
float const* B, lapack_int const* ldb,
float* C, lapack_int const* ldc, float* scale,
lapack_int* iwork, lapack_int const* liwork,
float* swork, lapack_int const *ldswork,
lapack_int* info
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t, size_t
#endif
);
/* Call through LAPACK_strsyl3(...). When LAPACK_FORTRAN_STRLEN_END is defined
 * the macro appends "1, 1" to fill the two trailing size_t parameters
 * (presumably the hidden lengths of trana and tranb -- confirm). */
#ifdef LAPACK_FORTRAN_STRLEN_END
#define LAPACK_strsyl3(...) LAPACK_strsyl3_base(__VA_ARGS__, 1, 1)
#else
#define LAPACK_strsyl3(...) LAPACK_strsyl3_base(__VA_ARGS__)
#endif

/* Prototype of the double-precision complex trsyl3 Fortran entry point.
 * LAPACK_GLOBAL (defined elsewhere) receives both the lower-case and the
 * upper-case spelling of the routine name. All arguments are passed by
 * pointer. This variant takes only the real workspace swork/ldswork. */
#define LAPACK_ztrsyl3_base LAPACK_GLOBAL(ztrsyl3,ZTRSYL3)
void LAPACK_ztrsyl3_base(
char const* trana, char const* tranb,
lapack_int const* isgn, lapack_int const* m, lapack_int const* n,
lapack_complex_double const* A, lapack_int const* lda,
lapack_complex_double const* B, lapack_int const* ldb,
lapack_complex_double* C, lapack_int const* ldc, double* scale,
double* swork, lapack_int const *ldswork,
lapack_int* info
#ifdef LAPACK_FORTRAN_STRLEN_END
, size_t, size_t
#endif
);
/* Call through LAPACK_ztrsyl3(...). When LAPACK_FORTRAN_STRLEN_END is defined
 * the macro appends "1, 1" to fill the two trailing size_t parameters
 * (presumably the hidden lengths of trana and tranb -- confirm). */
#ifdef LAPACK_FORTRAN_STRLEN_END
#define LAPACK_ztrsyl3(...) LAPACK_ztrsyl3_base(__VA_ARGS__, 1, 1)
#else
#define LAPACK_ztrsyl3(...) LAPACK_ztrsyl3_base(__VA_ARGS__)
#endif

#define LAPACK_ctrtri_base LAPACK_GLOBAL(ctrtri,CTRTRI)
void LAPACK_ctrtri_base(
char const* uplo, char const* diag,
Expand Down
39 changes: 39 additions & 0 deletions LAPACKE/include/lapacke.h
Original file line number Diff line number Diff line change
Expand Up @@ -4477,6 +4477,23 @@ lapack_int LAPACKE_ztrsyl( int matrix_layout, char trana, char tranb,
lapack_complex_double* c, lapack_int ldc,
double* scale );

lapack_int LAPACKE_strsyl3( int matrix_layout, char trana, char tranb,
lapack_int isgn, lapack_int m, lapack_int n,
const float* a, lapack_int lda, const float* b,
lapack_int ldb, float* c, lapack_int ldc,
float* scale );
lapack_int LAPACKE_dtrsyl3( int matrix_layout, char trana, char tranb,
lapack_int isgn, lapack_int m, lapack_int n,
const double* a, lapack_int lda, const double* b,
lapack_int ldb, double* c, lapack_int ldc,
double* scale );
lapack_int LAPACKE_ztrsyl3( int matrix_layout, char trana, char tranb,
lapack_int isgn, lapack_int m, lapack_int n,
const lapack_complex_double* a, lapack_int lda,
const lapack_complex_double* b, lapack_int ldb,
lapack_complex_double* c, lapack_int ldc,
double* scale );

lapack_int LAPACKE_strtri( int matrix_layout, char uplo, char diag, lapack_int n,
float* a, lapack_int lda );
lapack_int LAPACKE_dtrtri( int matrix_layout, char uplo, char diag, lapack_int n,
Expand Down Expand Up @@ -10174,6 +10191,28 @@ lapack_int LAPACKE_ztrsyl_work( int matrix_layout, char trana, char tranb,
lapack_complex_double* c, lapack_int ldc,
double* scale );

lapack_int LAPACKE_strsyl3_work( int matrix_layout, char trana, char tranb,
lapack_int isgn, lapack_int m, lapack_int n,
const float* a, lapack_int lda,
const float* b, lapack_int ldb,
float* c, lapack_int ldc, float* scale,
lapack_int* iwork, lapack_int liwork,
float* swork, lapack_int ldswork );
lapack_int LAPACKE_dtrsyl3_work( int matrix_layout, char trana, char tranb,
lapack_int isgn, lapack_int m, lapack_int n,
const double* a, lapack_int lda,
const double* b, lapack_int ldb,
double* c, lapack_int ldc, double* scale,
lapack_int* iwork, lapack_int liwork,
double* swork, lapack_int ldswork );
lapack_int LAPACKE_ztrsyl3_work( int matrix_layout, char trana, char tranb,
lapack_int isgn, lapack_int m, lapack_int n,
const lapack_complex_double* a, lapack_int lda,
const lapack_complex_double* b, lapack_int ldb,
lapack_complex_double* c, lapack_int ldc,
double* scale, double* swork,
lapack_int ldswork );

lapack_int LAPACKE_strtri_work( int matrix_layout, char uplo, char diag,
lapack_int n, float* a, lapack_int lda );
lapack_int LAPACKE_dtrtri_work( int matrix_layout, char uplo, char diag,
Expand Down
8 changes: 8 additions & 0 deletions LAPACKE/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ lapacke_ctrsna.c
lapacke_ctrsna_work.c
lapacke_ctrsyl.c
lapacke_ctrsyl_work.c
lapacke_ctrsyl3.c
lapacke_ctrsyl3_work.c
lapacke_ctrtri.c
lapacke_ctrtri_work.c
lapacke_ctrtrs.c
Expand Down Expand Up @@ -1169,6 +1171,8 @@ lapacke_dtrsna.c
lapacke_dtrsna_work.c
lapacke_dtrsyl.c
lapacke_dtrsyl_work.c
lapacke_dtrsyl3.c
lapacke_dtrsyl3_work.c
lapacke_dtrtri.c
lapacke_dtrtri_work.c
lapacke_dtrtrs.c
Expand Down Expand Up @@ -1740,6 +1744,8 @@ lapacke_strsna.c
lapacke_strsna_work.c
lapacke_strsyl.c
lapacke_strsyl_work.c
lapacke_strsyl3.c
lapacke_strsyl3_work.c
lapacke_strtri.c
lapacke_strtri_work.c
lapacke_strtrs.c
Expand Down Expand Up @@ -2314,6 +2320,8 @@ lapacke_ztrsna.c
lapacke_ztrsna_work.c
lapacke_ztrsyl.c
lapacke_ztrsyl_work.c
lapacke_ztrsyl3.c
lapacke_ztrsyl3_work.c
lapacke_ztrtri.c
lapacke_ztrtri_work.c
lapacke_ztrtrs.c
Expand Down
10 changes: 9 additions & 1 deletion LAPACKE/src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ include $(TOPSRCDIR)/make.inc

.SUFFIXES: .c .o
.c.o:
$(CC) $(CFLAGS) -I../include -c -o $@ $<
$(CC) $(CFLAGS) -Wall -I../include -c -o $@ $<

OBJ = \
lapacke_ilaver.o \
Expand Down Expand Up @@ -604,6 +604,8 @@ lapacke_ctrsna.o \
lapacke_ctrsna_work.o \
lapacke_ctrsyl.o \
lapacke_ctrsyl_work.o \
lapacke_ctrsyl3.o \
lapacke_ctrsyl3_work.o \
lapacke_ctrtri.o \
lapacke_ctrtri_work.o \
lapacke_ctrtrs.o \
Expand Down Expand Up @@ -1216,6 +1218,8 @@ lapacke_dtrsna.o \
lapacke_dtrsna_work.o \
lapacke_dtrsyl.o \
lapacke_dtrsyl_work.o \
lapacke_dtrsyl3.o \
lapacke_dtrsyl3_work.o \
lapacke_dtrtri.o \
lapacke_dtrtri_work.o \
lapacke_dtrtrs.o \
Expand Down Expand Up @@ -1782,6 +1786,8 @@ lapacke_strsna.o \
lapacke_strsna_work.o \
lapacke_strsyl.o \
lapacke_strsyl_work.o \
lapacke_strsyl3.o \
lapacke_strsyl3_work.o \
lapacke_strtri.o \
lapacke_strtri_work.o \
lapacke_strtrs.o \
Expand Down Expand Up @@ -2356,6 +2362,8 @@ lapacke_ztrsna.o \
lapacke_ztrsna_work.o \
lapacke_ztrsyl.o \
lapacke_ztrsyl_work.o \
lapacke_ztrsyl3.o \
lapacke_ztrsyl3_work.o \
lapacke_ztrtri.o \
lapacke_ztrtri_work.o \
lapacke_ztrtrs.o \
Expand Down
1 change: 0 additions & 1 deletion LAPACKE/src/lapacke_cgesvdq.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ lapack_int LAPACKE_cgesvdq( int matrix_layout, char joba, char jobp,
lapack_int lrwork = -1;
float* rwork = NULL;
float rwork_query;
lapack_int i;
if( matrix_layout != LAPACK_COL_MAJOR && matrix_layout != LAPACK_ROW_MAJOR ) {
LAPACKE_xerbla( "LAPACKE_cgesvdq", -1 );
return -1;
Expand Down
56 changes: 56 additions & 0 deletions LAPACKE/src/lapacke_ctrsyl3.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "lapacke_utils.h"

/*
 * High-level interface for the single-precision complex trsyl3 routine.
 *
 * Validates matrix_layout, optionally runs the NaN checks on a (m x m),
 * b (n x n) and c (m x n), asks LAPACKE_ctrsyl3_work for the workspace
 * dimensions (ldswork = -1), allocates that workspace, performs the real call
 * and releases the workspace again.
 *
 * Returns:
 *   -1                         matrix_layout is neither LAPACK_COL_MAJOR nor
 *                              LAPACK_ROW_MAJOR (reported via LAPACKE_xerbla)
 *   -7 / -9 / -11              the NaN check flagged a / b / c
 *   LAPACK_WORK_MEMORY_ERROR   the workspace could not be allocated
 *                              (reported via LAPACKE_xerbla)
 *   otherwise                  whatever LAPACKE_ctrsyl3_work returned
 */
lapack_int LAPACKE_ctrsyl3( int matrix_layout, char trana, char tranb,
                            lapack_int isgn, lapack_int m, lapack_int n,
                            const lapack_complex_float* a, lapack_int lda,
                            const lapack_complex_float* b, lapack_int ldb,
                            lapack_complex_float* c, lapack_int ldc,
                            float* scale )
{
    lapack_int status = 0;
    float dims_query[2];            /* filled in by the workspace query */
    float* workspace = NULL;
    lapack_int workspace_ld = -1;   /* -1 requests a workspace query */
    lapack_int workspace_len = -1;

    if( !( matrix_layout == LAPACK_COL_MAJOR ||
           matrix_layout == LAPACK_ROW_MAJOR ) ) {
        LAPACKE_xerbla( "LAPACKE_ctrsyl3", -1 );
        return -1;
    }
#ifndef LAPACK_DISABLE_NAN_CHECK
    if( LAPACKE_get_nancheck() ) {
        /* Input screening is optional and can be switched off at run time */
        if( LAPACKE_cge_nancheck( matrix_layout, m, m, a, lda ) ) {
            return -7;
        }
        if( LAPACKE_cge_nancheck( matrix_layout, n, n, b, ldb ) ) {
            return -9;
        }
        if( LAPACKE_cge_nancheck( matrix_layout, m, n, c, ldc ) ) {
            return -11;
        }
    }
#endif
    /* First pass: obtain the workspace dimensions */
    status = LAPACKE_ctrsyl3_work( matrix_layout, trana, tranb, isgn, m, n,
                                   a, lda, b, ldb, c, ldc, scale, dims_query,
                                   workspace_ld );
    if( status == 0 ) {
        /* The two dimensions come back as floats in dims_query[0..1] */
        workspace_ld = dims_query[0];
        workspace_len = workspace_ld * dims_query[1];
        workspace = (float*)LAPACKE_malloc( sizeof(float) * workspace_len );
        if( workspace == NULL ) {
            status = LAPACK_WORK_MEMORY_ERROR;
        } else {
            /* Second pass: the actual computation */
            status = LAPACKE_ctrsyl3_work( matrix_layout, trana, tranb, isgn,
                                           m, n, a, lda, b, ldb, c, ldc, scale,
                                           workspace, workspace_ld );
            LAPACKE_free( workspace );
        }
    }
    if( status == LAPACK_WORK_MEMORY_ERROR ) {
        LAPACKE_xerbla( "LAPACKE_ctrsyl3", status );
    }
    return status;
}
88 changes: 88 additions & 0 deletions LAPACKE/src/lapacke_ctrsyl3_work.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "lapacke_utils.h"

/*
 * Middle-level interface for the single-precision complex trsyl3 routine;
 * the caller supplies the workspace swork/ldswork.
 *
 * Column-major input is forwarded to LAPACK_ctrsyl3 unchanged. Row-major input
 * is copied into freshly allocated column-major temporaries (a: m x m,
 * b: n x n, c: m x n), the routine is run on those, and c is transposed back.
 *
 * A workspace query (ldswork == -1) in row-major layout is answered directly,
 * without allocating or transposing anything: the Fortran routine is only
 * expected to report the workspace dimensions in swork in that mode and not
 * to reference the matrices (per the CTRSYL3 documentation).
 *
 * Returns:
 *   -1                               unknown matrix_layout
 *   -8 / -10 / -12                   row-major lda < m / ldb < n / ldc < n
 *   LAPACK_TRANSPOSE_MEMORY_ERROR    a temporary could not be allocated
 *   otherwise                        info from LAPACK_ctrsyl3, with negative
 *                                    values decremented by one to account for
 *                                    the extra matrix_layout argument
 * All error returns listed above are also reported through LAPACKE_xerbla.
 */
lapack_int LAPACKE_ctrsyl3_work( int matrix_layout, char trana, char tranb,
                                 lapack_int isgn, lapack_int m, lapack_int n,
                                 const lapack_complex_float* a, lapack_int lda,
                                 const lapack_complex_float* b, lapack_int ldb,
                                 lapack_complex_float* c, lapack_int ldc,
                                 float* scale, float* swork,
                                 lapack_int ldswork )
{
    lapack_int info = 0;
    if( matrix_layout == LAPACK_COL_MAJOR ) {
        /* Call LAPACK function and adjust info */
        LAPACK_ctrsyl3( &trana, &tranb, &isgn, &m, &n, a, &lda, b, &ldb, c, &ldc,
                        scale, swork, &ldswork, &info );
        if( info < 0 ) {
            info = info - 1;
        }
    } else if( matrix_layout == LAPACK_ROW_MAJOR ) {
        lapack_int lda_t = MAX(1,m);
        lapack_int ldb_t = MAX(1,n);
        lapack_int ldc_t = MAX(1,m);
        lapack_complex_float* a_t = NULL;
        lapack_complex_float* b_t = NULL;
        lapack_complex_float* c_t = NULL;
        /* Check leading dimension(s) */
        if( lda < m ) {
            info = -8;
            LAPACKE_xerbla( "LAPACKE_ctrsyl3_work", info );
            return info;
        }
        if( ldb < n ) {
            info = -10;
            LAPACKE_xerbla( "LAPACKE_ctrsyl3_work", info );
            return info;
        }
        if( ldc < n ) {
            info = -12;
            LAPACKE_xerbla( "LAPACKE_ctrsyl3_work", info );
            return info;
        }
        /* Query optimal working array size if requested: skip the
         * temporaries, only the column-major leading dimensions matter */
        if( ldswork == -1 ) {
            LAPACK_ctrsyl3( &trana, &tranb, &isgn, &m, &n, a, &lda_t, b, &ldb_t,
                            c, &ldc_t, scale, swork, &ldswork, &info );
            return (info < 0) ? (info - 1) : info;
        }
        /* Allocate memory for temporary array(s) */
        a_t = (lapack_complex_float*)
            LAPACKE_malloc( sizeof(lapack_complex_float) * lda_t * MAX(1,m) );
        if( a_t == NULL ) {
            info = LAPACK_TRANSPOSE_MEMORY_ERROR;
            goto exit_level_0;
        }
        b_t = (lapack_complex_float*)
            LAPACKE_malloc( sizeof(lapack_complex_float) * ldb_t * MAX(1,n) );
        if( b_t == NULL ) {
            info = LAPACK_TRANSPOSE_MEMORY_ERROR;
            goto exit_level_1;
        }
        c_t = (lapack_complex_float*)
            LAPACKE_malloc( sizeof(lapack_complex_float) * ldc_t * MAX(1,n) );
        if( c_t == NULL ) {
            info = LAPACK_TRANSPOSE_MEMORY_ERROR;
            goto exit_level_2;
        }
        /* Transpose input matrices */
        LAPACKE_cge_trans( matrix_layout, m, m, a, lda, a_t, lda_t );
        LAPACKE_cge_trans( matrix_layout, n, n, b, ldb, b_t, ldb_t );
        LAPACKE_cge_trans( matrix_layout, m, n, c, ldc, c_t, ldc_t );
        /* Call LAPACK function and adjust info */
        LAPACK_ctrsyl3( &trana, &tranb, &isgn, &m, &n, a_t, &lda_t, b_t, &ldb_t,
                        c_t, &ldc_t, scale, swork, &ldswork, &info );
        if( info < 0 ) {
            info = info - 1;
        }
        /* Transpose output matrices */
        LAPACKE_cge_trans( LAPACK_COL_MAJOR, m, n, c_t, ldc_t, c, ldc );
        /* Release memory and exit; labels unwind in reverse allocation order */
        LAPACKE_free( c_t );
exit_level_2:
        LAPACKE_free( b_t );
exit_level_1:
        LAPACKE_free( a_t );
exit_level_0:
        if( info == LAPACK_TRANSPOSE_MEMORY_ERROR ) {
            LAPACKE_xerbla( "LAPACKE_ctrsyl3_work", info );
        }
    } else {
        info = -1;
        LAPACKE_xerbla( "LAPACKE_ctrsyl3_work", info );
    }
    return info;
}
1 change: 0 additions & 1 deletion LAPACKE/src/lapacke_dgesvdq.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ lapack_int LAPACKE_dgesvdq( int matrix_layout, char joba, char jobp,
lapack_int lrwork = -1;
double* rwork = NULL;
double rwork_query;
lapack_int i;
if( matrix_layout != LAPACK_COL_MAJOR && matrix_layout != LAPACK_ROW_MAJOR ) {
LAPACKE_xerbla( "LAPACKE_dgesvdq", -1 );
return -1;
Expand Down
Loading