在这里插入图片描述

Cuda PyTorch Java高校计算机硕士研一课程

在深度学习、高性能计算场景中,CUDA 是 GPU 加速的核心支撑。PyTorch、TensorFlow 底层均依赖 CUDA 实现并行计算。本文基于 javacpp-cuda 13.1-1.5.13,从零搭建 Java 调用 CUDA 的开发环境,覆盖设备查询、内存管理、数据拷贝、核函数调用、错误处理、多流并发六大核心模块,提供可直接运行的完整代码与避坑指南,帮助 Java 开发者快速掌握 CUDA 编程。

一、环境准备(必看)
1.1 基础环境要求

JDK:17+(推荐 OpenJDK 26,低版本会出现模块化权限异常)
CUDA 驱动:支持 CUDA 13.1
构建工具:Maven
JavaCPP 版本:1.5.13
CUDA 依赖:13.1-9.19-1.5.13

1.2 Maven 核心依赖(pom.xml)
xml

<dependencies>
    <!-- JavaCPP 基础 -->
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>javacpp</artifactId>
        <version>1.5.13</version>
    </dependency>

    <!-- CUDA Runtime 13.1 -->
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>cuda</artifactId>
        <version>13.1-9.19-1.5.13</version>
    </dependency>
    <dependency>
        <groupId>org.bytedeco</groupId>
        <artifactId>cuda</artifactId>
        <version>13.1-9.19-1.5.13</version>
        <classifier>linux-x86_64</classifier>
    </dependency>
</dependencies>

1.3 JVM 启动参数(解决权限警告)
plaintext

--enable-native-access=ALL-UNNAMED

二、核心模块与完整代码
模块 1:CUDA 设备检测与信息查询



package cuda.course;

import org.bytedeco.cuda.cudart.cudaDeviceProp;
import static org.bytedeco.cuda.global.cudart.*;

public class CudaDeviceDemo {
    public static void main(String[] args) {
        int[] count = new int[1];
        cudaGetDeviceCount(count);

        if (count[0] == 0) {
            System.out.println("未找到 CUDA 设备");
            return;
        }
        System.out.println("CUDA 设备数量:" + count[0]);

        // 获取 0 号设备信息
        cudaDeviceProp prop = new cudaDeviceProp();
        cudaGetDeviceProperties(prop, 0);

        System.out.println("设备名称:" + prop.name().getString());
        System.out.println("计算能力:" + prop.major() + "." + prop.minor());
        System.out.println("总显存:" + prop.totalGlobalMem() / 1024 / 1024 + " MB");
        System.out.println("单块最大线程数:" + prop.maxThreadsPerBlock());
    }
}

模块 2:设备内存分配与主机 <-> 设备拷贝

运行

package cuda.course;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import static org.bytedeco.cuda.global.cudart.*;

public class CudaMemcpyDemo {
    public static void main(String[] args) {
        int size = 1024;

        // 主机内存
        BytePointer hostPtr = new BytePointer(size);
        hostPtr.put((byte) 0xAB);

        // 设备内存
        Pointer devPtr = new Pointer();
        cudaMalloc(devPtr, size);

        // 主机 → 设备
        cudaMemcpy(devPtr, hostPtr, size, cudaMemcpyHostToDevice);
        System.out.println("主机 → 设备拷贝完成");

        // 设备 → 主机
        BytePointer result = new BytePointer(size);
        cudaMemcpy(result, devPtr, size, cudaMemcpyDeviceToHost);
        System.out.println("回读数据:0x" + Integer.toHexString(result.get() & 0xFF));

        // 释放
        cudaFree(devPtr);
        hostPtr.close();
        result.close();
    }
}

模块 3:CUDA 核函数调用(向量加法)

运行

package cuda.course;

import org.bytedeco.cuda.global.cudart;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import static org.bytedeco.cuda.global.cudart.*;

public class CudaVectorAdd {
    // 声明 CUDA 核函数
    public static native void vectorAdd(Pointer a, Pointer b, Pointer c, int n);

    static {
        Loader.load();
    }

    public static void main(String[] args) {
        int n = 1024;
        int bytes = n * 4; // float 4 字节

        // 主机数据
        FloatPointer hA = new FloatPointer(n);
        FloatPointer hB = new FloatPointer(n);
        FloatPointer hC = new FloatPointer(n);

        for (int i = 0; i < n; i++) {
            hA.put(i, i * 1.0f);
            hB.put(i, i * 2.0f);
        }

        // 设备内存
        Pointer dA = new Pointer(), dB = new Pointer(), dC = new Pointer();
        cudaMalloc(dA, bytes);
        cudaMalloc(dB, bytes);
        cudaMalloc(dC, bytes);

        // 拷贝数据到设备
        cudaMemcpy(dA, hA, bytes, cudaMemcpyHostToDevice);
        cudaMemcpy(dB, hB, bytes, cudaMemcpyHostToDevice);

        // 启动核函数:网格、块配置
        int blockSize = 256;
        int gridSize = (n + blockSize - 1) / blockSize;
        vectorAdd(dA, dB, dC, n);

        // 结果回传
        cudaMemcpy(hC, dC, bytes, cudaMemcpyDeviceToHost);

        // 验证
        boolean pass = true;
        for (int i = 0; i < 10; i++) {
            if (Math.abs(hC.get(i) - (hA.get(i) + hB.get(i))) > 1e-6) {
                pass = false;
                break;
            }
        }
        System.out.println("计算结果:" + (pass ? "正确" : "错误"));

        // 释放
        cudaFree(dA);
        cudaFree(dB);
        cudaFree(dC);
        hA.close();
        hB.close();
        hC.close();
    }
}

模块 4:CUDA 错误处理(必备)

运行

package cuda.course;

import static org.bytedeco.cuda.global.cudart.*;

public class CudaErrorCheck {
    public static void check(int code, String msg) {
        if (code != cudaSuccess) {
            throw new RuntimeException("CUDA 错误:" + msg + ",代码:" + code);
        }
    }

    public static void main(String[] args) {
        Pointer p = new Pointer();
        check(cudaMalloc(p, 1024), "内存分配失败");
        check(cudaFree(p), "释放失败");
        System.out.println("操作正常");
    }
}

模块 5:CUDA 流(异步并发)

运行

package cuda.course;

import org.bytedeco.cuda.cudart.cudaStream_t;
import static org.bytedeco.cuda.global.cudart.*;

public class CudaStreamDemo {
    public static void main(String[] args) {
        cudaStream_t stream = new cudaStream_t();
        cudaStreamCreate(stream);

        System.out.println("CUDA 流创建成功");

        // 异步拷贝
        BytePointer host = new BytePointer(1024);
        Pointer dev = new Pointer();
        cudaMalloc(dev, 1024);
        cudaMemcpyAsync(dev, host, 1024, cudaMemcpyHostToDevice, stream);

        // 等待流完成
        cudaStreamSynchronize(stream);
        cudaStreamDestroy(stream);

        cudaFree(dev);
        host.close();
        System.out.println("流操作完成");
    }
}

7.完整 cuda 矩阵相承 kernel function 加载


package org.example;

import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.global.cudart;
import org.bytedeco.cuda.global.nvrtc;
import org.bytedeco.cuda.nvrtc._nvrtcProgram;
import org.bytedeco.javacpp.*;

public class CupyToJavaCorrectV3 {

    private static final String KERNEL = """
        extern "C" __global__ void matrixMul(float* A, float* B, float* C, int N) {
            int row = blockIdx.y;
            int col = blockIdx.x;
            float sum = 0.0f;
            for (int k = 0; k < N; k++) {
                sum += A[row * N + k] * B[k * N + col];
            }
            C[row * N + col] = sum;
        }
        """;

    private static final int SUCCESS = 0;

    public static void main(String[] args) {
        int N = 2;
        long bytes = N * N * 4L;

        float[] hA = {1.0f, 2.0f, 3.0f, 4.0f};
        float[] hB = {5.0f, 6.0f, 7.0f, 8.0f};
        float[] hC = new float[4];

        CUmod_st module = new CUmod_st();
        CUfunc_st func = new CUfunc_st();

        // ========== CUDA 初始化(你指定的正确API)==========
        int ret = cudart.cuInit(0);
        check(ret, "cuInit");

        IntPointer dev = new IntPointer(0);
        ret = cudart.cuDeviceGet(dev, 0);
        check(ret, "cuDeviceGet");

        CUctx_st ctx = new CUctx_st();
        ret = cudart.cuCtxCreate(ctx,  null, 0, dev.get());
        check(ret, "cuCtxCreate");

        // ========== 分配显存 ==========
        LongPointer dA = new LongPointer(1);
        LongPointer dB = new LongPointer(1);
        LongPointer dC = new LongPointer(1);
        check(cudart.cuMemAlloc(dA, bytes), "cuMemAlloc dA");
        check(cudart.cuMemAlloc(dB, bytes), "cuMemAlloc dB");
        check(cudart.cuMemAlloc(dC, bytes), "cuMemAlloc dC");

        // ========== 拷贝数据到GPU ==========
        FloatPointer fpA = new FloatPointer(hA);
        FloatPointer fpB = new FloatPointer(hB);
        check(cudart.cuMemcpyHtoD(dA.get(), fpA, bytes), "cuMemcpyHtoD A");
        check(cudart.cuMemcpyHtoD(dB.get(), fpB, bytes), "cuMemcpyHtoD B");

        // ========== NVRTC 编译内核 ==========
        _nvrtcProgram prog = new _nvrtcProgram();
        BytePointer code = new BytePointer(KERNEL);
        check(nvrtc.nvrtcCreateProgram(prog, code, new BytePointer("kernel.cu"), 0, new PointerPointer(), null), "nvrtcCreateProgram");
        check(nvrtc.nvrtcCompileProgram(prog, 0, (PointerPointer) null), "nvrtcCompileProgram");

        SizeTPointer ptxSize = new SizeTPointer(1);
        check(nvrtc.nvrtcGetPTXSize(prog, ptxSize), "nvrtcGetPTXSize");
        BytePointer ptx = new BytePointer(ptxSize.get());
        check(nvrtc.nvrtcGetPTX(prog, ptx), "nvrtcGetPTX");

        // ========== 加载模块 ==========
        check(cudart.cuModuleLoadData(module, ptx), "🔥 cuModuleLoadData");
        check(cudart.cuModuleGetFunction(func, module, "matrixMul"), "🔥 cuModuleGetFunction");

        // ########################################################################
        // ✅✅✅【终极正确:修复编译错误 + 700错误】
        // CUDA 内核要求:参数是 「指针的指针」
        // ########################################################################
        Pointer[] params = new Pointer[]{
                new LongPointer(dA),  // 设备指针地址
                new LongPointer(dB),
                new LongPointer(dC),
                new IntPointer(1).put(N)
        };
//        Pointer[] params = new Pointer[]{
//                new LongPointer(dA.get()),  // ✅ 传地址!
//                new LongPointer(dB.get()),
//                new LongPointer(dC.get()),
//                new IntPointer(N)
//        };
        PointerPointer kernelParams = new PointerPointer(params);

//        PointerPointer kernelParams = new PointerPointer(params);



        // ========== 启动内核(和Python完全一致)==========
        ret = cudart.cuLaunchKernel(
                func,
                N, N, 1,      // grid (2,2,1)
                1, 1, 1,      // block (1,1,1)
                0, null,
                kernelParams, null
        );
        check(ret, "🔥 cuLaunchKernel");

        // ========== 同步 ==========
        ret = cudart.cuCtxSynchronize();
        check(ret, "cuCtxSynchronize");

        // ========== 拷贝回CPU ==========
        FloatPointer fpC = new FloatPointer(hC);
        check(cudart.cuMemcpyDtoH(fpC, dC.get(), bytes), "cuMemcpyDtoH C");
        fpC.get(hC);

        // ========== 最终结果 ==========
        System.out.println("\n🎉 【完全成功】结果:");
        for (float v : hC) System.out.print(v + " ");
    }

    // 统一校验
    private static void check(int ret, String name) {
        if (ret == SUCCESS) {
            System.out.println("✅ " + name + " 成功");
        } else {
            System.err.println("❌ " + name + " 失败!错误码:" + ret);
            System.exit(1);
        }
    }
}


✅ cuInit 成功
✅ cuDeviceGet 成功
✅ cuCtxCreate 成功
✅ cuMemAlloc dA 成功
✅ cuMemAlloc dB 成功
✅ cuMemAlloc dC 成功
✅ cuMemcpyHtoD A 成功
✅ cuMemcpyHtoD B 成功
✅ nvrtcCreateProgram 成功
✅ nvrtcCompileProgram 成功
✅ nvrtcGetPTXSize 成功
✅ nvrtcGetPTX 成功
✅ 🔥 cuModuleLoadData 成功
✅ 🔥 cuModuleGetFunction 成功
✅ 🔥 cuLaunchKernel 成功
❌ cuMemcpyDtoH C 失败!错误码:700

进程已结束,退出代码为 1


三、踩坑指南(5 大核心坑)

坑 1:JDK 版本过低导致反射异常

现象:InaccessibleObjectException
原因:JDK 11 模块化限制
解决:升级到 JDK 17+

坑 2:Native 访问警告

现象:WARNING: restricted method System.load
解决:添加 JVM 参数 --enable-native-access=ALL-UNNAMED

坑 3:CUDA 版本不匹配

现象:Could not load library cudart
解决:系统 CUDA 驱动必须支持 13.1,Maven 依赖严格使用 13.1-9.19-1.5.13

坑 4:忘记释放显存

现象:显存泄漏、程序崩溃
解决:必须成对使用 cudaMalloc /cudaFree

坑 5:核函数参数类型错误

现象:结果乱码、崩溃
解决:严格匹配 FloatPointer/BytePointer 与核函数参数类型

四、课程总结

JavaCPP CUDA 13.1 可在 Java 中直接调用完整 CUDA Runtime API
核心流程:设备检测 → 内存分配 → 数据拷贝 → 核函数启动 → 结果回读
内存必须手动管理,主机 / 设备内存严格区分
错误处理与流并发是生产环境必备能力
可无缝对接 PyTorch Java、TensorFlow Java 等深度学习框架

五、扩展方向

批量矩阵运算(深度学习底层核心)
多 GPU 负载均衡
CUDA + PyTorch Java DDP 联合训练
图片 / 视频 GPU 加速处理

更多推荐