Skip to content
Snippets Groups Projects
Commit 925082f7 authored by ndoucet's avatar ndoucet
Browse files

add cusolver posv

parent 61b9aefa
Branches cublas
Tags
1 merge request!5add cusolver potrs
......@@ -24,6 +24,9 @@ namespace cusolver
template<typename T>
void potrf(const handle_t & handle, FillMode uplo, int N, T *A, int lda, T *workspace, int lwork, int *devInfo);
template<typename T>
void potrs(const handle_t & handle, FillMode uplo, int N, int nrhs, T *A, int lda, T *B, int ldb, int *devInfo);
template<typename T>
void getrf_bufferSize(const handle_t & handle, int m, int n, T *A, int lda, int *Lwork );
......
......@@ -21,6 +21,7 @@ namespace cusolver
struct CuSolver<float> {
constexpr static auto potrf_bufferSize = cusolverDnSpotrf_bufferSize;
constexpr static auto potrf = cusolverDnSpotrf;
constexpr static auto potrs = cusolverDnSpotrs;
constexpr static auto getrf_bufferSize = cusolverDnSgetrf_bufferSize;
constexpr static auto getrf = cusolverDnSgetrf;
constexpr static auto getrs = cusolverDnSgetrs;
......@@ -31,6 +32,7 @@ namespace cusolver
struct CuSolver<double> {
constexpr static auto potrf_bufferSize = cusolverDnDpotrf_bufferSize;
constexpr static auto potrf = cusolverDnDpotrf;
constexpr static auto potrs = cusolverDnDpotrs;
constexpr static auto getrf_bufferSize = cusolverDnDgetrf_bufferSize;
constexpr static auto getrf = cusolverDnDgetrf;
constexpr static auto getrs = cusolverDnDgetrs;
......@@ -41,6 +43,7 @@ namespace cusolver
struct CuSolver<cuComplex> {
constexpr static auto potrf_bufferSize = cusolverDnCpotrf_bufferSize;
constexpr static auto potrf = cusolverDnCpotrf;
constexpr static auto potrs = cusolverDnCpotrs;
constexpr static auto getrf_bufferSize = cusolverDnCgetrf_bufferSize;
constexpr static auto getrf = cusolverDnCgetrf;
constexpr static auto getrs = cusolverDnCgetrs;
......@@ -51,6 +54,7 @@ namespace cusolver
struct CuSolver<cuDoubleComplex> {
constexpr static auto potrf_bufferSize = cusolverDnZpotrf_bufferSize;
constexpr static auto potrf = cusolverDnZpotrf;
constexpr static auto potrs = cusolverDnZpotrs;
constexpr static auto getrf_bufferSize = cusolverDnZgetrf_bufferSize;
constexpr static auto getrf = cusolverDnZgetrf;
constexpr static auto getrs = cusolverDnZgetrs;
......@@ -77,6 +81,16 @@ namespace cusolver
template void potrf<cuComplex> (const handle_t & handle, FillMode uplo, int N, cuComplex *A, int lda, cuComplex *workspace, int lwork, int *devInfo);
template void potrf<cuDoubleComplex>(const handle_t & handle, FillMode uplo, int N, cuDoubleComplex *A, int lda, cuDoubleComplex *workspace, int lwork, int *devInfo);
template<typename T>
void potrs(const handle_t & handle, FillMode uplo, int N, int nrhs, T *A, int lda, T *B, int ldb, int *devInfo) {
throw_if_error(CuSolver<T>::potrs(handle.id(), convert(uplo), N, nrhs, A, lda, B, ldb, devInfo));
}
template void potrs<float> (const handle_t & handle, FillMode uplo, int N, int nrhs, float *A, int lda, float *B, int ldb, int *devInfo);
template void potrs<double> (const handle_t & handle, FillMode uplo, int N, int nrhs, double *A, int lda, double *B, int ldb, int *devInfo);
template void potrs<cuComplex> (const handle_t & handle, FillMode uplo, int N, int nrhs, cuComplex *A, int lda, cuComplex *B, int ldb, int *devInfo);
template void potrs<cuDoubleComplex>(const handle_t & handle, FillMode uplo, int N, int nrhs, cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb, int *devInfo);
template<typename T>
void getrf_bufferSize(const handle_t & handle, int m, int n, T *A, int lda, int *Lwork ){
throw_if_error(CuSolver<T>::getrf_bufferSize(handle.id(), m, n, A, lda, Lwork));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment