Hello,
When the method was introduced, I wrote a quick test, but it's a bit messy and not extensively verified.
I'll add a task to my to-do list to verify it, clean it up, and extend it so that it can become a sample.
In the meantime, here's the code; maybe it already helps you a little:
package tests.jcublas;
import jcuda.*;
import jcuda.jcublas.*;
import jcuda.runtime.*;
import utils.Utils;
class JCublas2TestSgemmBatched
{
public static void main(String[] args)
{
JCublas2.setExceptionsEnabled(true);
JCuda.setExceptionsEnabled(true);
testSgemmBatched(10, 500);
}
public static boolean testSgemmBatched(int b, int n)
{
System.out.println("=== Testing Sgemm with "+b+" batches of size " + n + " ===");
float alpha = 0.3f;
float beta = 0.7f;
int nn = n * n;
// System.out.println("Creating input data...");
float h_A[][] = new float**[];
float h_B[][] = new float**[];
float h_C[][] = new float**[];
float h_C_ref[][] = new float**[];
for (int i=0; i<b; i++)
{
h_A** = Utils.createRandomFloatData1D(nn);
h_B** = Utils.createRandomFloatData1D(nn);
h_C** = Utils.createRandomFloatData1D(nn);
h_C_ref** = h_C**.clone();
}
System.out.println("Performing Sgemm with Java...");
sgemmJava(n, alpha, h_A, h_B, beta, h_C_ref);
System.out.println("Performing Sgemm with JCublas2...");
sgemmBatchesJCublas2(n, alpha, h_A, h_B, beta, h_C);
// Print the test results
boolean passed = true;
for (int i=0; i<b; i++)
{
passed &= Utils.equalNorm1D(h_C**, h_C_ref**);
}
System.out.println(String.format("testSgemm %s",
passed ? "PASSED" : "FAILED"));
return passed;
}
static void sgemmBatchesJCublas2(int n, float alpha, float h_A[][],
float h_B[][], float beta, float h_C[][])
{
//JCublas2.setLogLevel(LogLevel.LOG_DEBUGTRACE);
int nn = n * n;
int b = h_A.length;
Pointer[] h_Aarray = new Pointer**;
Pointer[] h_Barray = new Pointer**;
Pointer[] h_Carray = new Pointer**;
for (int i=0; i<b; i++)
{
h_Aarray** = new Pointer();
h_Barray** = new Pointer();
h_Carray** = new Pointer();
JCuda.cudaMalloc(h_Aarray**, nn * Sizeof.FLOAT);
JCuda.cudaMalloc(h_Barray**, nn * Sizeof.FLOAT);
JCuda.cudaMalloc(h_Carray**, nn * Sizeof.FLOAT);
JCublas2.cublasSetVector(nn, Sizeof.FLOAT, Pointer.to(h_A**), 1, h_Aarray**, 1);
JCublas2.cublasSetVector(nn, Sizeof.FLOAT, Pointer.to(h_B**), 1, h_Barray**, 1);
JCublas2.cublasSetVector(nn, Sizeof.FLOAT, Pointer.to(h_C**), 1, h_Carray**, 1);
}
Pointer d_Aarray = new Pointer();
Pointer d_Barray = new Pointer();
Pointer d_Carray = new Pointer();
JCuda.cudaMalloc(d_Aarray, b * Sizeof.POINTER);
JCuda.cudaMalloc(d_Barray, b * Sizeof.POINTER);
JCuda.cudaMalloc(d_Carray, b * Sizeof.POINTER);
JCuda.cudaMemcpy(d_Aarray, Pointer.to(h_Aarray), b * Sizeof.POINTER, cudaMemcpyKind.cudaMemcpyHostToDevice);
JCuda.cudaMemcpy(d_Barray, Pointer.to(h_Barray), b * Sizeof.POINTER, cudaMemcpyKind.cudaMemcpyHostToDevice);
JCuda.cudaMemcpy(d_Carray, Pointer.to(h_Carray), b * Sizeof.POINTER, cudaMemcpyKind.cudaMemcpyHostToDevice);
cublasHandle handle = new cublasHandle();
JCublas2.cublasCreate(handle);
JCublas2.cublasSgemmBatched(
handle,
cublasOperation.CUBLAS_OP_N,
cublasOperation.CUBLAS_OP_N,
n, n, n,
Pointer.to(new float[]{alpha}),
d_Aarray, n, d_Barray, n,
Pointer.to(new float[]{beta}),
d_Carray, n, b);
for (int i=0; i<b; i++)
{
JCublas2.cublasGetVector(nn, Sizeof.FLOAT, h_Carray**, 1, Pointer.to(h_C**), 1);
JCuda.cudaFree(h_Aarray**);
JCuda.cudaFree(h_Barray**);
JCuda.cudaFree(h_Carray**);
}
JCuda.cudaFree(d_Aarray);
JCuda.cudaFree(d_Barray);
JCuda.cudaFree(d_Carray);
JCublas2.cublasDestroy(handle);
}
static void sgemmJava(int n, float alpha, float A[][], float B[][], float beta, float C[][])
{
for (int i=0; i<A.length; i++)
{
sgemmJava(n, alpha, A**, B**, beta, C**);
}
}
static void sgemmJava(int n, float alpha, float A[], float B[], float beta, float C[])
{
int i;
int j;
int k;
for (i = 0; i < n; ++i)
{
for (j = 0; j < n; ++j)
{
float prod = 0;
for (k = 0; k < n; ++k)
{
prod += A[k * n + i] * B[j * n + k];
}
C[j * n + i] = alpha * prod + beta * C[j * n + i];
}
}
}
}
The methods from the `Utils` class are straightforward:
/**
 * Compares two arrays over the full length of {@code a}.
 * <p>
 * Shorthand for {@code equalNorm1D(a, b, a.length)}.
 *
 * @param a the reference array
 * @param b the array compared against {@code a}
 * @return the result of {@code equalNorm1D(a, b, a.length)}
 */
public static boolean equalNorm1D(float a[], float b[])
{
    final int count = a.length;
    return equalNorm1D(a, b, count);
}
public static boolean equalNorm1D(float a[], float b[], int n)
{
if (a.length < n || b.length < n)
{
return false;
}
float errorNorm = 0;
float refNorm = 0;
for (int i = 0; i < n; i++)
{
float diff = a** - b**;
errorNorm += diff * diff;
refNorm += a** * a**;
}
errorNorm = (float)Math.sqrt(errorNorm);
refNorm = (float)Math.sqrt(refNorm);
return (errorNorm / refNorm < 1e-6f);
}
public static float[] createRandomFloatData1D(int x)
{
float a[] = new float[x];
for (int i=0; i<x; i++)
{
a** = random.nextFloat();
}
return a;
}
(This class is used in some internal tests, but similar methods are available in the classes of the Utilities package at jcuda.org – see the "Utilities" page.)
I hope that helps. Otherwise, I'll try to schedule the creation of the sample somewhere between … the update for CUDA 4.2, the update for CUDA 5.0, the update for OpenCL 1.2, and … all the other tasks.
bye
Marco