Hi, I am working with JCuda on the matrix multiplication example, but I
have two CUDA devices in a single PC and I don't know how to use the
second device in this example.
This is the code:
package org.jppf.example.prueba;
import java.util.Random;
import jcuda.*;
import jcuda.jcublas.JCublas;
public class TestMatrixMultiplication {
//final int size;
public TestMatrixMultiplication(final int newSize)
//public static void main(String args[])
{
//this.size = newSize;
testSgemm(newSize);
}
public static void testSgemm(int n)
{
float alpha = 1.0f;
float beta = 0.0f;
int nn = n * n;
System.out.println("Creando datos de entrada...");
float h_A[] = createRandomFloatData(nn);
float h_B[] = createRandomFloatData(nn);
float h_C[] = createRandomFloatData(nn);
float h_C_ref[] = h_C.clone();
System.out.println("Ejecutando la mutiplicacion de Matrices sequencial...");
long A=System.currentTimeMillis();
sgemmJava(n, alpha, h_A, h_B, beta, h_C_ref);
long B=System.currentTimeMillis()-A;
System.out.println("Tiempo de ejecución igual a: "+B+" milisegundos");
System.out.println("Ejecutando la multiplicacion de matrices con JCublas...");
A=System.currentTimeMillis();
sgemmJCublas(n, alpha, h_A, h_B, beta, h_C);
B=System.currentTimeMillis()-A;
System.out.println("Tiempo de ejecución igual a: "+B+" milisegundos");
boolean passed = isCorrectResult(h_C, h_C_ref);
System.out.println("testMatrixMultiplication: "+(passed?"PASSED":"FAILED"));
}
private static void sgemmJCublas(int n, float alpha, float A[], float B[],
float beta, float C[])
{
int nn = n * n;
// Initialize JCublas
JCublas.cublasInit();
// Allocate memory on the device
Pointer d_A = new Pointer();
Pointer d_B = new Pointer();
Pointer d_C = new Pointer();
JCublas.cublasAlloc(nn, Sizeof.FLOAT, d_A);
JCublas.cublasAlloc(nn, Sizeof.FLOAT, d_B);
JCublas.cublasAlloc(nn, Sizeof.FLOAT, d_C);
// Copy the memory from the host to the device
JCublas.cublasSetVector(nn, Sizeof.FLOAT, Pointer.to(A), 1, d_A, 1);
JCublas.cublasSetVector(nn, Sizeof.FLOAT, Pointer.to(B), 1, d_B, 1);
JCublas.cublasSetVector(nn, Sizeof.FLOAT, Pointer.to(C), 1, d_C, 1);
// Execute sgemm
JCublas.cublasSgemm(
'n', 'n', n, n, n, alpha, d_A, n, d_B, n, beta, d_C, n);
// Copy the result from the device to the host
JCublas.cublasGetVector(nn, Sizeof.FLOAT, d_C, 1, Pointer.to(C), 1);
// Clean up
JCublas.cublasFree(d_A);
JCublas.cublasFree(d_B);
JCublas.cublasFree(d_C);
JCublas.cublasShutdown();
}
private static void sgemmJava(int n, float alpha, float A[], float B[],
float beta, float C[])
{
for (int i = 0; i < n; ++i)
{
for (int j = 0; j < n; ++j)
{
float prod = 0;
for (int k = 0; k < n; ++k)
{
prod += A[k * n + i] * B[j * n + k];
}
C[j * n + i] = alpha * prod + beta * C[j * n + i];
}
}
}
private static float[] createRandomFloatData(int n)
{
Random random = new Random();
float x[] = new float[n];
for (int i = 0; i < n; i++)
{
x** = random.nextFloat();
}
return x;
}
private static boolean isCorrectResult(float result[], float reference[])
{
float errorNorm = 0;
float refNorm = 0;
for (int i = 0; i < result.length; ++i)
{
float diff = reference** - result**;
errorNorm += diff * diff;
refNorm += reference** * result**;
}
errorNorm = (float) Math.sqrt(errorNorm);
refNorm = (float) Math.sqrt(refNorm);
if (Math.abs(refNorm) < 1e-6)
{
return false;
}
return (errorNorm / refNorm < 1e-6f);
}
}