cublasSgemmBatched() usage


#1

Greetings,
I’ve been trying to use the cublasSgemmBatched() function in JCuda for matrix multiplication, and I’m
not sure how to properly handle pointer passing and the arrays of batched matrices. I would be really
thankful if someone could show me how to modify my code to handle this properly.

/**
 * Multiplies a single m-by-k matrix A with a single k-by-n matrix B on the
 * GPU via cublasSgemmBatched, using a batch of size 1 (C = 1*A*B + 0*C).
 *
 * cuBLAS interprets the flat arrays in column-major order, so with
 * CUBLAS_OP_N the leading dimensions are the row counts of the stored
 * matrices: lda = m, ldb = k, ldc = m.
 *
 * NOTE(review): the result is copied into a local array only, because the
 * method returns void; callers currently have no way to read it.
 *
 * @param m number of rows of A and C
 * @param n number of columns of B and C
 * @param k number of columns of A and rows of B
 * @param A host data of A, at least m*k floats
 * @param B host data of B, at least k*n floats
 */
public static void SsmmBatchJCublas(int m, int n, int k, float A[], float B[]){

		// Create a CUBLAS handle
		cublasHandle handle = new cublasHandle();
		cublasCreate(handle);

		// Device buffers holding the matrix data
		Pointer d_A = new Pointer();
		Pointer d_B = new Pointer();
		Pointer d_C = new Pointer();

		// Device buffers holding the *arrays of pointers* to the matrices.
		// cublasSgemmBatched dereferences these on the GPU, so they must
		// live in device memory; a host-side Pointer.to(Pointer[]) is not
		// a valid argument.
		Pointer d_Aarray = new Pointer();
		Pointer d_Barray = new Pointer();
		Pointer d_Carray = new Pointer();

		float[] C = new float[m*n];
		try {
			// Allocate memory on the device
			cudaMalloc(d_A, m*k * Sizeof.FLOAT);
			cudaMalloc(d_B, n*k * Sizeof.FLOAT);
			cudaMalloc(d_C, m*n * Sizeof.FLOAT);

			// Copy the inputs from the host to the device. C is not
			// uploaded: with beta = 0 it is only written, never read.
			cublasSetVector(m*k, Sizeof.FLOAT, Pointer.to(A), 1, d_A, 1);
			cublasSetVector(n*k, Sizeof.FLOAT, Pointer.to(B), 1, d_B, 1);

			// Build the pointer tables on the host (one entry per batch
			// element) and copy them to the device.
			Pointer[] Aarray = new Pointer[]{d_A};
			Pointer[] Barray = new Pointer[]{d_B};
			Pointer[] Carray = new Pointer[]{d_C};
			int batchCount = Aarray.length;

			cudaMalloc(d_Aarray, batchCount * Sizeof.POINTER);
			cudaMalloc(d_Barray, batchCount * Sizeof.POINTER);
			cudaMalloc(d_Carray, batchCount * Sizeof.POINTER);
			cublasSetVector(batchCount, Sizeof.POINTER, Pointer.to(Aarray), 1, d_Aarray, 1);
			cublasSetVector(batchCount, Sizeof.POINTER, Pointer.to(Barray), 1, d_Barray, 1);
			cublasSetVector(batchCount, Sizeof.POINTER, Pointer.to(Carray), 1, d_Carray, 1);

			// Execute sgemm
			Pointer pAlpha = Pointer.to(new float[]{1});
			Pointer pBeta = Pointer.to(new float[]{0});

			// Leading dimensions are the row counts of the column-major
			// matrices (m for A, k for B, m for C), not the batch size.
			cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
				pAlpha, d_Aarray, m, d_Barray, k, pBeta, d_Carray, m, batchCount);

			// Copy the result from the device to the host
			cublasGetVector(m*n, Sizeof.FLOAT, d_C, 1, Pointer.to(C), 1);
		} finally {
			// Clean up, also when one of the calls above threw.
			// Freeing a Pointer that was never allocated is a no-op.
			cudaFree(d_Aarray);
			cudaFree(d_Barray);
			cudaFree(d_Carray);
			cudaFree(d_A);
			cudaFree(d_B);
			cudaFree(d_C);
			cublasDestroy(handle);
		}
	}

#2

Hello

When the method was introduced I wrote a quick test, but it’s a bit messy, and not extensively verified.

I’ll add the task to my “todo” list to verify and clean it up, and extend it in order to become a sample.

However, here’s the code; maybe it already helps you a little:

package tests.jcublas;

import jcuda.*;
import jcuda.jcublas.*;
import jcuda.runtime.*;
import utils.Utils;

/**
 * Small self-checking test for JCublas2.cublasSgemmBatched.
 *
 * It multiplies several batches of square n-by-n matrices once with a plain
 * Java reference implementation and once with cublasSgemmBatched, and then
 * compares the results with Utils.equalNorm1D.
 *
 * All matrices are stored as flat arrays in column-major order, i.e. the
 * element in row i and column j is at index j * n + i.
 */
class JCublas2TestSgemmBatched
{
    /**
     * Enables exceptions for JCublas2 and JCuda and runs the test with
     * 10 batches of 500x500 matrices.
     */
    public static void main(String[] args)
    {
        JCublas2.setExceptionsEnabled(true);
        JCuda.setExceptionsEnabled(true);
        testSgemmBatched(10, 500);
    }

    /**
     * Runs C = alpha * A * B + beta * C for b random matrix triples of size
     * n-by-n, both in Java and with JCublas2, and compares the results.
     *
     * @param b the number of batches (matrix triples)
     * @param n the number of rows and columns of every matrix
     * @return whether the JCublas2 result matched the Java result for
     *         every batch
     */
    public static boolean testSgemmBatched(int b, int n)
    {
        System.out.println("=== Testing Sgemm with "+b+" batches of size " + n + " ===");

        float alpha = 0.3f;
        float beta = 0.7f;
        int nn = n * n;

        // Create the random input data; the reference run works on a copy
        // of C because both implementations update C in place.
        float h_A[][] = new float[b][];
        float h_B[][] = new float[b][];
        float h_C[][] = new float[b][];
        float h_C_ref[][] = new float[b][];
        for (int i=0; i<b; i++)
        {
            h_A[i] = Utils.createRandomFloatData1D(nn);
            h_B[i] = Utils.createRandomFloatData1D(nn);
            h_C[i] = Utils.createRandomFloatData1D(nn);
            h_C_ref[i] = h_C[i].clone();
        }

        System.out.println("Performing Sgemm with Java...");
        sgemmJava(n, alpha, h_A, h_B, beta, h_C_ref);

        System.out.println("Performing Sgemm with JCublas2...");
        sgemmBatchesJCublas2(n, alpha, h_A, h_B, beta, h_C);

        // Print the test results
        boolean passed = true;
        for (int i=0; i<b; i++)
        {
            passed &= Utils.equalNorm1D(h_C[i], h_C_ref[i]);
        }
        System.out.println(String.format("testSgemm %s", 
            passed ? "PASSED" : "FAILED"));
        return passed;
    }

    /**
     * Computes h_C[i] = alpha * h_A[i] * h_B[i] + beta * h_C[i] for all
     * batches with a single call to cublasSgemmBatched, and writes the
     * results back into h_C.
     *
     * Every device allocation and the cuBLAS handle are released in a
     * finally block, so nothing is leaked when a CUDA or cuBLAS call throws
     * (exceptions are enabled in main).
     *
     * @param n the number of rows and columns of every matrix
     * @param alpha the scaling factor for the product A * B
     * @param h_A host data of the A matrices, one flat array per batch
     * @param h_B host data of the B matrices, one flat array per batch
     * @param beta the scaling factor for the previous contents of C
     * @param h_C host data of the C matrices, updated in place
     */
    static void sgemmBatchesJCublas2(int n, float alpha, float h_A[][],
                    float h_B[][], float beta, float h_C[][])
    {
        //JCublas2.setLogLevel(LogLevel.LOG_DEBUGTRACE);
        
        int nn = n * n;
        int b = h_A.length;

        // Host-side tables of device pointers, one entry per batch
        Pointer[] h_Aarray = new Pointer[b];
        Pointer[] h_Barray = new Pointer[b];
        Pointer[] h_Carray = new Pointer[b];

        // Device-side copies of these tables, as required by cuBLAS
        Pointer d_Aarray = null;
        Pointer d_Barray = null;
        Pointer d_Carray = null;

        cublasHandle handle = new cublasHandle();
        boolean handleCreated = false;
        try
        {
            for (int i=0; i<b; i++)
            {
                h_Aarray[i] = uploadFloats(h_A[i], nn);
                h_Barray[i] = uploadFloats(h_B[i], nn);
                h_Carray[i] = uploadFloats(h_C[i], nn);
            }
            d_Aarray = uploadPointers(h_Aarray);
            d_Barray = uploadPointers(h_Barray);
            d_Carray = uploadPointers(h_Carray);

            JCublas2.cublasCreate(handle);
            handleCreated = true;

            // Square column-major matrices: m = n = k = lda = ldb = ldc = n
            JCublas2.cublasSgemmBatched(
                handle, 
                cublasOperation.CUBLAS_OP_N, 
                cublasOperation.CUBLAS_OP_N, 
                n, n, n, 
                Pointer.to(new float[]{alpha}), 
                d_Aarray, n, d_Barray, n, 
                Pointer.to(new float[]{beta}), 
                d_Carray, n, b);

            for (int i=0; i<b; i++)
            {
                JCublas2.cublasGetVector(nn, Sizeof.FLOAT, h_Carray[i], 1, Pointer.to(h_C[i]), 1);
            }
        }
        finally
        {
            for (int i=0; i<b; i++)
            {
                freeIfAllocated(h_Aarray[i]);
                freeIfAllocated(h_Barray[i]);
                freeIfAllocated(h_Carray[i]);
            }
            freeIfAllocated(d_Aarray);
            freeIfAllocated(d_Barray);
            freeIfAllocated(d_Carray);
            if (handleCreated)
            {
                JCublas2.cublasDestroy(handle);
            }
        }
    }

    /**
     * Allocates device memory for count floats and copies the first count
     * elements of the given host array into it. The allocation is freed
     * again if the copy fails.
     */
    private static Pointer uploadFloats(float[] host, int count)
    {
        Pointer device = new Pointer();
        JCuda.cudaMalloc(device, count * Sizeof.FLOAT);
        try
        {
            JCublas2.cublasSetVector(count, Sizeof.FLOAT, Pointer.to(host), 1, device, 1);
        }
        catch (RuntimeException e)
        {
            JCuda.cudaFree(device);
            throw e;
        }
        return device;
    }

    /**
     * Allocates device memory for a table of pointers and copies the given
     * device pointers into it. The allocation is freed again if the copy
     * fails.
     */
    private static Pointer uploadPointers(Pointer[] devicePointers)
    {
        int bytes = devicePointers.length * Sizeof.POINTER;
        Pointer table = new Pointer();
        JCuda.cudaMalloc(table, bytes);
        try
        {
            JCuda.cudaMemcpy(table, Pointer.to(devicePointers), bytes,
                cudaMemcpyKind.cudaMemcpyHostToDevice);
        }
        catch (RuntimeException e)
        {
            JCuda.cudaFree(table);
            throw e;
        }
        return table;
    }

    /**
     * Frees the given device pointer, skipping entries that were never
     * created because an earlier step failed.
     */
    private static void freeIfAllocated(Pointer pointer)
    {
        if (pointer != null)
        {
            JCuda.cudaFree(pointer);
        }
    }

    /**
     * Reference implementation for all batches: applies the single-matrix
     * sgemmJava to every triple (A[i], B[i], C[i]).
     */
    static void sgemmJava(int n, float alpha, float A[][], float B[][], float beta, float C[][])
    {
        for (int i=0; i<A.length; i++)
        {
            sgemmJava(n, alpha, A[i], B[i], beta, C[i]);
        }
    }
    
    /**
     * Reference implementation of C = alpha * A * B + beta * C for square
     * n-by-n matrices stored column-major in flat arrays. C is updated in
     * place.
     */
    static void sgemmJava(int n, float alpha, float A[], float B[], float beta, float C[])
    {
        for (int i = 0; i < n; ++i)
        {
            for (int j = 0; j < n; ++j)
            {
                // Dot product of row i of A and column j of B
                float prod = 0;
                for (int k = 0; k < n; ++k)
                {
                    prod += A[k * n + i] * B[j * n + k];
                }
                C[j * n + i] = alpha * prod + beta * C[j * n + i];
            }
        }
    }
}

The methods from the ‘Utils’ class are straightforward:

    /**
     * Compares two float arrays by the relative norm of their difference,
     * looking at the first a.length elements of both arrays.
     *
     * @param a the first array; its length determines how many elements
     *          are compared
     * @param b the second array
     * @return the result of equalNorm1D(a, b, a.length)
     */
    public static boolean equalNorm1D(float a[], float b[])
    {
        final int count = a.length;
        return equalNorm1D(a, b, count);
    }
    
    /**
     * Checks whether the first n elements of two float arrays are equal up
     * to a relative error of 1e-6, measured as the Euclidean norm of the
     * difference divided by the Euclidean norm of a.
     *
     * If the norm of a is zero (all compared elements of a are zero, or
     * n is 0), the relative error is undefined; the arrays are then
     * considered equal exactly when the difference norm is zero as well.
     * Without this rule the division would yield NaN and identical
     * all-zero arrays would be reported as different.
     *
     * @param a the first (reference) array
     * @param b the second array
     * @param n the number of leading elements to compare
     * @return false if either array has fewer than n elements, otherwise
     *         whether the relative error is below 1e-6
     */
    public static boolean equalNorm1D(float a[], float b[], int n)
    {
        if (a.length < n || b.length < n)
        {
            return false;
        }
        float errorNorm = 0;
        float refNorm = 0;
        for (int i = 0; i < n; i++) 
        {
            float diff = a[i] - b[i];
            errorNorm += diff * diff;
            refNorm += a[i] * a[i];
        }
        errorNorm = (float)Math.sqrt(errorNorm);
        refNorm = (float)Math.sqrt(refNorm);
        if (refNorm == 0)
        {
            // Avoid 0/0 = NaN: a zero reference matches only a zero difference
            return errorNorm == 0;
        }
        return (errorNorm / refNorm < 1e-6f);
    }

    /**
     * Creates a new array with x elements, each filled with the next
     * value returned by random.nextFloat().
     *
     * @param x the number of elements
     * @return the newly created array
     */
    public static float[] createRandomFloatData1D(int x)
    {
        final float[] data = new float[x];
        int index = 0;
        while (index < data.length)
        {
            data[index] = random.nextFloat();
            index++;
        }
        return data;
    }

(This class is used in some internal tests, but similar methods are available in the classes from the Utilities package at http://jcuda.org/utilities/utilities.html )

Hope that helps, otherwise I’ll try to schedule the creation of the sample somewhere between … the update for CUDA 4.2, the update for CUDA 5.0, the update for OpenCL 1.2, and … the other tasks :wink:

bye
Marco