It's really difficult. It would be a pity if you spent your (and my) time on "just doing something on the GPU with CUDA, because that's 'research' and you can get a PhD for that", without making sure that the result makes any sense whatsoever.
I tried to explain at length that you really have to think carefully about how you want to structure your data, how to send it to the GPU, and what you are actually going to do with it there. I showed implementation options and tried to explain their possible pros and cons, including the crucial question of the string lengths.
To put it bluntly: trying to tokenize a string on the GPU by searching for the "," comma does not make sense at all. Keep in mind that the kernel is executed by thousands of threads, in parallel, at the same time. Searching for a "," is nothing that you can sensibly do there: a thread would first have to know where its token starts, and that depends on all the commas before it. That's simply not what CUDA is intended for. Besides, you already have this information on the Java side, so you can just pass it to the GPU as it is.
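As a minimal sketch of what "passing it as it is" could look like, assume (hypothetically) that the input is a single comma-separated line that has already been read on the Java side. The token boundaries can then be determined once on the host with `String.split`, and the resulting `int[]` of start indices could be copied to the device like any other array. The class and method names are made up for illustration.

```java
import java.util.Arrays;

public class TokenOffsets
{
    /**
     * Computes, on the Java side, the start index of each comma-separated
     * token in the given line. An array like this could be copied to the
     * device instead of letting the kernel search for commas.
     */
    static int[] tokenStartIndices(String line)
    {
        // A limit of -1 keeps trailing empty tokens
        String[] tokens = line.split(",", -1);
        int[] starts = new int[tokens.length];
        int offset = 0;
        for (int i = 0; i < tokens.length; i++)
        {
            starts[i] = offset;
            // Skip the token itself and the comma that follows it
            offset += tokens[i].length() + 1;
        }
        return starts;
    }

    public static void main(String[] args)
    {
        int[] starts = tokenStartIndices("This,will,be");
        System.out.println(Arrays.toString(starts)); // prints [0, 5, 10]
    }
}
```

With such an array, thread `i` knows directly that its token starts at `starts[i]`, without scanning for anything.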
Here is another example: It takes a List<String> as the input, and the kernel writes the last letter of each of these strings into the output memory.
To accomplish that, the program joins the strings into a single string, which is then passed to the GPU. It also stores the end indices of the individual strings inside the combined string as an int[] array, and passes this to the GPU as well. The "end indices" are basically the accumulated string lengths. This is *only* intended for this example: Yes, this does not make sense in general. No, you should not just copy and paste this and try to solve your task with it. You should think about whether you need…
- The string lengths
- The string start indices
- The string end indices
- The accumulated start indices
- The accumulated end indices
- Some combination of these
and you should think about this carefully, depending on what you want to do with these strings in the kernel, and then try to implement it.
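To illustrate how some of these arrays relate to each other, here is a small sketch for ASCII strings (where `String.length()` equals the number of bytes). The helper names are hypothetical. The end indices are the inclusive prefix sums of the lengths (this is what the example below passes to the kernel), and the start indices are the exclusive prefix sums, so that string `i` occupies the range `[start[i], end[i])` of the joined string.

```java
import java.util.Arrays;
import java.util.List;

public class StringIndexArrays
{
    /** The length of each string (assumes ASCII, so chars == bytes). */
    static int[] lengths(List<String> strings)
    {
        int[] result = new int[strings.size()];
        for (int i = 0; i < result.length; i++)
        {
            result[i] = strings.get(i).length();
        }
        return result;
    }

    /** Start index of each string in the joined string (exclusive prefix sum). */
    static int[] startIndices(int[] lengths)
    {
        int[] result = new int[lengths.length];
        int sum = 0;
        for (int i = 0; i < lengths.length; i++)
        {
            result[i] = sum;
            sum += lengths[i];
        }
        return result;
    }

    /** End index (exclusive) of each string in the joined string (inclusive prefix sum). */
    static int[] endIndices(int[] lengths)
    {
        int[] result = new int[lengths.length];
        int sum = 0;
        for (int i = 0; i < lengths.length; i++)
        {
            sum += lengths[i];
            result[i] = sum;
        }
        return result;
    }

    public static void main(String[] args)
    {
        List<String> strings = Arrays.asList("This", "will", "be");
        int[] lengths = lengths(strings);
        System.out.println(Arrays.toString(lengths));               // [4, 4, 2]
        System.out.println(Arrays.toString(startIndices(lengths))); // [0, 4, 8]
        System.out.println(Arrays.toString(endIndices(lengths)));   // [4, 8, 10]
    }
}
```

Note that `start[i+1] == end[i]`, so passing one of these arrays is often enough; which one is more convenient depends on what the kernel actually does with the strings.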
Maybe the example is helpful in some way, nevertheless:
```java
import static jcuda.driver.JCudaDriver.cuCtxCreate;
import static jcuda.driver.JCudaDriver.cuCtxDestroy;
import static jcuda.driver.JCudaDriver.cuCtxSynchronize;
import static jcuda.driver.JCudaDriver.cuDeviceGet;
import static jcuda.driver.JCudaDriver.cuInit;
import static jcuda.driver.JCudaDriver.cuLaunchKernel;
import static jcuda.driver.JCudaDriver.cuMemAlloc;
import static jcuda.driver.JCudaDriver.cuMemFree;
import static jcuda.driver.JCudaDriver.cuMemcpyDtoH;
import static jcuda.driver.JCudaDriver.cuMemcpyHtoD;
import static jcuda.driver.JCudaDriver.cuModuleGetFunction;
import static jcuda.driver.JCudaDriver.cuModuleLoadData;
import static jcuda.driver.JCudaDriver.cuModuleUnload;
import static jcuda.nvrtc.JNvrtc.nvrtcCompileProgram;
import static jcuda.nvrtc.JNvrtc.nvrtcCreateProgram;
import static jcuda.nvrtc.JNvrtc.nvrtcDestroyProgram;
import static jcuda.nvrtc.JNvrtc.nvrtcGetPTX;
import static jcuda.nvrtc.JNvrtc.nvrtcGetProgramLog;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;
import jcuda.nvrtc.JNvrtc;
import jcuda.nvrtc.nvrtcProgram;

/**
 * An example that passes a list of strings to a kernel as ONE joined
 * string, together with the end index of each string. The kernel writes
 * the last letter of each string into the output memory.
 */
public class JCudaStringsWithLengths
{
    // Each thread i reads the character just before the end index of
    // string i, which is the last letter of that string
    private static final String programSourceCode =
        "extern \"C\"" + "\n" +
        "__global__ void copy(int numStrings, char *srcStringsJoined, int *srcStringEndIndices, char *lastLetters)" + "\n" +
        "{" + "\n" +
        "    int i = blockIdx.x * blockDim.x + threadIdx.x;" + "\n" +
        "    if (i<numStrings)" + "\n" +
        "    {" + "\n" +
        "        int endIndex = srcStringEndIndices[i];" + "\n" +
        "        char lastLetter = srcStringsJoined[endIndex-1];" + "\n" +
        "        lastLetters[i] = lastLetter;" + "\n" +
        "    }" + "\n" +
        "}" + "\n";

    public static void main(String[] args)
    {
        // Default initialization
        JCudaDriver.setExceptionsEnabled(true);
        JNvrtc.setExceptionsEnabled(true);
        cuInit(0);
        CUdevice device = new CUdevice();
        cuDeviceGet(device, 0);
        CUcontext context = new CUcontext();
        cuCtxCreate(context, 0, device);

        // Compile the kernel source code into PTX with NVRTC
        nvrtcProgram program = new nvrtcProgram();
        nvrtcCreateProgram(
            program, programSourceCode, null, 0, null, null);
        nvrtcCompileProgram(program, 0, null);
        String[] programLog = new String[1];
        nvrtcGetProgramLog(program, programLog);
        System.out.println("Program compilation log:\n" + programLog[0]);
        String[] ptx = new String[1];
        nvrtcGetPTX(program, ptx);
        nvrtcDestroyProgram(program);

        // Load the PTX and obtain the kernel function
        CUmodule module = new CUmodule();
        cuModuleLoadData(module, ptx[0]);
        CUfunction function = new CUfunction();
        cuModuleGetFunction(function, module, "copy");

        // Define the input data on the host: The input is a list of
        // strings. It will be concatenated into ONE string. This example
        // assumes ASCII strings, so that each char is one byte, and
        // String.length() matches the number of bytes in the joined data
        List<String> srcStrings = Arrays.asList(
            "This", "will", "be", "horribly", "inefficient");
        int numStrings = srcStrings.size();
        String srcStringsJoined = String.join("", srcStrings);
        byte[] srcStringsJoinedData =
            srcStringsJoined.getBytes(StandardCharsets.US_ASCII);
        int totalLength = srcStringsJoinedData.length;

        // The end index of each string is stored in an array
        // (these are the accumulated string lengths):
        int[] srcStringEndIndices = new int[numStrings];
        int endIndex = 0;
        for (int i = 0; i < numStrings; i++)
        {
            endIndex += srcStrings.get(i).length();
            srcStringEndIndices[i] = endIndex;
        }

        // Copy the string data to the device
        CUdeviceptr srcStringsJoinedPointer = new CUdeviceptr();
        cuMemAlloc(srcStringsJoinedPointer, totalLength * Sizeof.BYTE);
        cuMemcpyHtoD(srcStringsJoinedPointer,
            Pointer.to(srcStringsJoinedData), totalLength * Sizeof.BYTE);

        // Copy the string end indices to the device
        CUdeviceptr srcStringEndIndicesPointer = new CUdeviceptr();
        cuMemAlloc(srcStringEndIndicesPointer, numStrings * Sizeof.INT);
        cuMemcpyHtoD(srcStringEndIndicesPointer,
            Pointer.to(srcStringEndIndices), numStrings * Sizeof.INT);

        // Allocate output memory: one byte (the last letter) per string
        CUdeviceptr dstDevice = new CUdeviceptr();
        cuMemAlloc(dstDevice, numStrings * Sizeof.BYTE);

        // Run the kernel with one thread per string
        Pointer kernelParameters = Pointer.to(
            Pointer.to(new int[]{numStrings}),
            Pointer.to(srcStringsJoinedPointer),
            Pointer.to(srcStringEndIndicesPointer),
            Pointer.to(dstDevice)
        );
        int blockSizeX = 256;
        int gridSizeX = (numStrings + blockSizeX - 1) / blockSizeX;
        cuLaunchKernel(function,
            gridSizeX, 1, 1,
            blockSizeX, 1, 1,
            0, null,
            kernelParameters, null
        );
        cuCtxSynchronize();

        // Copy output data to the host (cuMemcpyDtoH is synchronous)
        byte[] lastLetters = new byte[numStrings];
        cuMemcpyDtoH(Pointer.to(lastLetters), dstDevice,
            numStrings * Sizeof.BYTE);

        // Print the result
        System.out.println("Last letters: " +
            new String(lastLetters, StandardCharsets.US_ASCII));

        // Clean up
        cuMemFree(srcStringsJoinedPointer);
        cuMemFree(srcStringEndIndicesPointer);
        cuMemFree(dstDevice);
        cuModuleUnload(module);
        cuCtxDestroy(context);
    }
}
```
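When trying something like this, it helps to compute the expected result on the host as well, and to compare it with what the GPU returns. The following is a sketch of such a reference for the `copy` kernel above (class and method names are hypothetical). The two nested loops mimic the launch configuration: `blockIdx` and `threadIdx` play the roles of `blockIdx.x` and `threadIdx.x`, and `i` is computed and bounds-checked exactly as in the kernel. Like the kernel, it assumes that no string is empty, because otherwise `endIndices[i] - 1` would point into the previous string (or to -1).

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

public class LastLettersReference
{
    /**
     * Host-side reference for the "copy" kernel. Each (block, thread)
     * pair handles at most one string, like the kernel's index i.
     */
    static byte[] lastLetters(byte[] joined, int[] endIndices, int blockSizeX)
    {
        int numStrings = endIndices.length;
        int gridSizeX = (numStrings + blockSizeX - 1) / blockSizeX;
        byte[] result = new byte[numStrings];
        for (int blockIdx = 0; blockIdx < gridSizeX; blockIdx++)
        {
            for (int threadIdx = 0; threadIdx < blockSizeX; threadIdx++)
            {
                int i = blockIdx * blockSizeX + threadIdx;
                // Same bounds check as in the kernel: surplus threads do nothing
                if (i < numStrings)
                {
                    result[i] = joined[endIndices[i] - 1];
                }
            }
        }
        return result;
    }

    public static void main(String[] args)
    {
        List<String> strings = Arrays.asList(
            "This", "will", "be", "horribly", "inefficient");
        byte[] joined = String.join("", strings)
            .getBytes(StandardCharsets.US_ASCII);
        int[] endIndices = new int[strings.size()];
        int endIndex = 0;
        for (int i = 0; i < strings.size(); i++)
        {
            endIndex += strings.get(i).length();
            endIndices[i] = endIndex;
        }
        byte[] expected = lastLetters(joined, endIndices, 256);
        // Prints "sleyt", to be compared with the output of the GPU
        System.out.println(new String(expected, StandardCharsets.US_ASCII));
    }
}
```

The bounds check matters because the grid is rounded up to whole blocks: with 5 strings and a block size of 256, 251 of the threads have nothing to do.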