Amesos2 - Direct Sparse Solver Interfaces Version of the Day
Amesos2_cuSOLVER_FunctionMap.hpp
1// @HEADER
2//
3// ***********************************************************************
4//
5// Amesos2: Templated Direct Sparse Solver Package
6// Copyright 2011 Sandia Corporation
7//
8// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9// the U.S. Government retains certain rights in this software.
10//
11// Redistribution and use in source and binary forms, with or without
12// modification, are permitted provided that the following conditions are
13// met:
14//
15// 1. Redistributions of source code must retain the above copyright
16// notice, this list of conditions and the following disclaimer.
17//
18// 2. Redistributions in binary form must reproduce the above copyright
19// notice, this list of conditions and the following disclaimer in the
20// documentation and/or other materials provided with the distribution.
21//
22// 3. Neither the name of the Corporation nor the names of the
23// contributors may be used to endorse or promote products derived from
24// this software without specific prior written permission.
25//
26// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
27// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
29// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
30// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
31// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
32// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
33// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
34// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
35// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37//
38// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
39//
40// ***********************************************************************
41//
42// @HEADER
43
44#ifndef AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
45#define AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
46
48#include "Amesos2_cuSOLVER_TypeMap.hpp"
49
50#include <cuda.h>
51#include <cusolverSp.h>
52#include <cusolverDn.h>
53#include <cusparse.h>
54#include <cusolverSp_LOWLEVEL_PREVIEW.h>
55
56#ifdef HAVE_TEUCHOS_COMPLEX
57#include <cuComplex.h>
58#endif
59
60namespace Amesos2 {
61
62 template <>
63 struct FunctionMap<cuSOLVER,double>
64 {
65 static cusolverStatus_t bufferInfo(
66 cusolverSpHandle_t handle,
67 int size,
68 int nnz,
69 cusparseMatDescr_t & desc,
70 const double * values,
71 const int * rowPtr,
72 const int * colIdx,
73 csrcholInfo_t & chol_info,
74 size_t * internalDataInBytes,
75 size_t * workspaceInBytes)
76 {
77 cusolverStatus_t status =
78 cusolverSpDcsrcholBufferInfo(handle, size, nnz, desc, values,
79 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
80 return status;
81 }
82
83 static cusolverStatus_t numeric(
84 cusolverSpHandle_t handle,
85 int size,
86 int nnz,
87 cusparseMatDescr_t & desc,
88 const double * values,
89 const int * rowPtr,
90 const int * colIdx,
91 csrcholInfo_t & chol_info,
92 void * buffer)
93 {
94 cusolverStatus_t status = cusolverSpDcsrcholFactor(
95 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
96 return status;
97 }
98
99 static cusolverStatus_t solve(
100 cusolverSpHandle_t handle,
101 int size,
102 const double * b,
103 double * x,
104 csrcholInfo_t & chol_info,
105 void * buffer)
106 {
107 cusolverStatus_t status = cusolverSpDcsrcholSolve(
108 handle, size, b, x, chol_info, buffer);
109 return status;
110 }
111 };
112
113 template <>
114 struct FunctionMap<cuSOLVER,float>
115 {
116 static cusolverStatus_t bufferInfo(
117 cusolverSpHandle_t handle,
118 int size,
119 int nnz,
120 cusparseMatDescr_t & desc,
121 const float * values,
122 const int * rowPtr,
123 const int * colIdx,
124 csrcholInfo_t & chol_info,
125 size_t * internalDataInBytes,
126 size_t * workspaceInBytes)
127 {
128 cusolverStatus_t status =
129 cusolverSpScsrcholBufferInfo(handle, size, nnz, desc, values,
130 rowPtr, colIdx, chol_info, internalDataInBytes, workspaceInBytes);
131 return status;
132 }
133
134 static cusolverStatus_t numeric(
135 cusolverSpHandle_t handle,
136 int size,
137 int nnz,
138 cusparseMatDescr_t & desc,
139 const float * values,
140 const int * rowPtr,
141 const int * colIdx,
142 csrcholInfo_t & chol_info,
143 void * buffer)
144 {
145 cusolverStatus_t status = cusolverSpScsrcholFactor(
146 handle, size, nnz, desc, values, rowPtr, colIdx, chol_info, buffer);
147 return status;
148 }
149
150 static cusolverStatus_t solve(
151 cusolverSpHandle_t handle,
152 int size,
153 const float * b,
154 float * x,
155 csrcholInfo_t & chol_info,
156 void * buffer)
157 {
158 cusolverStatus_t status = cusolverSpScsrcholSolve(
159 handle, size, b, x, chol_info, buffer);
160 return status;
161 }
162 };
163
164#ifdef HAVE_TEUCHOS_COMPLEX
165 template <>
166 struct FunctionMap<cuSOLVER,Kokkos::complex<double>>
167 {
168 static cusolverStatus_t bufferInfo(
169 cusolverSpHandle_t handle,
170 int size,
171 int nnz,
172 cusparseMatDescr_t & desc,
173 const void * values,
174 const int * rowPtr,
175 const int * colIdx,
176 csrcholInfo_t & chol_info,
177 size_t * internalDataInBytes,
178 size_t * workspaceInBytes)
179 {
180 typedef cuDoubleComplex scalar_t;
181 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
182 cusolverStatus_t status =
183 cusolverSpZcsrcholBufferInfo(handle, size, nnz, desc,
184 cu_values, rowPtr, colIdx, chol_info,
185 internalDataInBytes, workspaceInBytes);
186 return status;
187 }
188
189 static cusolverStatus_t numeric(
190 cusolverSpHandle_t handle,
191 int size,
192 int nnz,
193 cusparseMatDescr_t & desc,
194 const void * values,
195 const int * rowPtr,
196 const int * colIdx,
197 csrcholInfo_t & chol_info,
198 void * buffer)
199 {
200 typedef cuDoubleComplex scalar_t;
201 const scalar_t * cu_values =
202 reinterpret_cast<const scalar_t *>(values);
203 cusolverStatus_t status = cusolverSpZcsrcholFactor(
204 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
205 return status;
206 }
207
208 static cusolverStatus_t solve(
209 cusolverSpHandle_t handle,
210 int size,
211 const void * b,
212 void * x,
213 csrcholInfo_t & chol_info,
214 void * buffer)
215 {
216 typedef cuDoubleComplex scalar_t;
217 const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
218 scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
219 cusolverStatus_t status = cusolverSpZcsrcholSolve(
220 handle, size, cu_b, cu_x, chol_info, buffer);
221 return status;
222 }
223 };
224
225 template <>
226 struct FunctionMap<cuSOLVER,Kokkos::complex<float>>
227 {
228 static cusolverStatus_t bufferInfo(
229 cusolverSpHandle_t handle,
230 int size,
231 int nnz,
232 cusparseMatDescr_t & desc,
233 const void * values,
234 const int * rowPtr,
235 const int * colIdx,
236 csrcholInfo_t & chol_info,
237 size_t * internalDataInBytes,
238 size_t * workspaceInBytes)
239 {
240 typedef cuFloatComplex scalar_t;
241 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
242 cusolverStatus_t status =
243 cusolverSpCcsrcholBufferInfo(handle, size, nnz, desc,
244 cu_values, rowPtr, colIdx, chol_info,
245 internalDataInBytes, workspaceInBytes);
246 return status;
247 }
248
249 static cusolverStatus_t numeric(
250 cusolverSpHandle_t handle,
251 int size,
252 int nnz,
253 cusparseMatDescr_t & desc,
254 const void * values,
255 const int * rowPtr,
256 const int * colIdx,
257 csrcholInfo_t & chol_info,
258 void * buffer)
259 {
260 typedef cuFloatComplex scalar_t;
261 const scalar_t * cu_values = reinterpret_cast<const scalar_t *>(values);
262 cusolverStatus_t status = cusolverSpCcsrcholFactor(
263 handle, size, nnz, desc, cu_values, rowPtr, colIdx, chol_info, buffer);
264 return status;
265 }
266
267 static cusolverStatus_t solve(
268 cusolverSpHandle_t handle,
269 int size,
270 const void * b,
271 void * x,
272 csrcholInfo_t & chol_info,
273 void * buffer)
274 {
275 typedef cuFloatComplex scalar_t;
276 const scalar_t * cu_b = reinterpret_cast<const scalar_t *>(b);
277 scalar_t * cu_x = reinterpret_cast<scalar_t *>(x);
278 cusolverStatus_t status = cusolverSpCcsrcholSolve(
279 handle, size, cu_b, cu_x, chol_info, buffer);
280 return status;
281 }
282 };
283#endif
284
285} // end namespace Amesos2
286
287#endif // AMESOS2_CUSOLVER_FUNCTIONMAP_HPP
Declarations of the cuSOLVER specializations of the Amesos2 FunctionMap class.