gpu-embedding-layer

GPUEmbeddingLayer 实现分析

概述

GPUEmbeddingLayer 是 NV Embedding Cache SDK 中最直接、最精简的嵌入层实现。它将所有嵌入数据存储在 GPU 线性内存(Linear Memory)中,查找操作完全在 GPU 上完成,不涉及主机(Host)回退。其核心思路是:如果完整嵌入表可以装入显存,那么直接在 GPU 上进行查找是最快的选择,无需任何缓存或层级回退。

该层位于 include/gpu_embedding_layer.hpp(声明)和 src/gpu_embedding_layer.cu(实现),模板参数为 KeyType,目前只支持 int32_tint64_t 两种键类型。


设计原理

定位:无缓存的纯 GPU 嵌入层

HierarchicalEmbeddingLayer(三级缓存)和 LinearUVMEmbeddingLayer(GPU 缓存 + UVM 线性内存)不同,GPUEmbeddingLayer 不管理任何缓存。它假设用户已经将完整嵌入表放在了 GPU 显存中,通过 config.embedding_table 指针传入。查找时直接通过 cuEmbed 库的 EmbeddingForward kernel 从线性表中读取数据,写入输出缓冲区。

这对应了 README 中描述的第一种配置场景:

所有嵌入都分配在线性 GPU 内存中:使用 GPUEmbeddingLayer(C++)/ 带 NoCache 缓存类型的 NVEmbedding(Python)。

适用场景

  • 嵌入表规模较小,可以完整装入单张 GPU 显存
  • 追求最低的查找延迟(无需任何 host 回退或缓存未命中处理)
  • 推理场景的 baseline 配置,用于与其他缓存层做性能对比

配置结构体:GPUEmbeddingLayerConfig

struct GPUEmbeddingLayerConfig {
  std::string layer_name;             // 层名称(用于日志/调试)
  int device_id{0};                   // 使用的 GPU 设备编号
  void* embedding_table;              // GPU 线性内存中的嵌入表指针
  int64_t num_embeddings;             // 嵌入表行数
  int64_t embedding_width_in_bytes;   // 每行嵌入向量的字节宽度
  DataType_t value_dtype;             // 存储数据类型(仅用于 accumulate)
};

关键说明:

  • embedding_table 必须是 GPU 设备内存指针,由用户预先分配和填充
  • embedding_width_in_bytes 目前只支持 fp16 和 fp32 类型,且必须能被 2 整除
  • 配置支持 JSON 序列化/反序列化(from_json / to_json

内部数据结构

template <typename KeyType>
class GPUEmbeddingLayer : public EmbeddingLayerBase {
 private:
  GPUEmbeddingLayerConfig config_;    // 层配置
  allocator_ptr_t allocator_;         // 内存分配器

  std::mutex kernel_launch_mutex_;     // 内核启动互斥锁
  std::shared_ptr<ContextRegistry> contexts_;  // 执行上下文注册表
  cudaEvent_t modify_in_progress_;     // 修改操作进行中的 CUDA 事件
  cudaStream_t private_modify_stream_; // 私有修改流
};

关键的同步基础设施:

  • kernel_launch_mutex_:保护所有 kernel 启动和 CUDA 事件排队操作,防止多个线程同时向 GPU 流提交命令
  • contexts_ContextRegistry 跟踪所有存活的执行上下文,用于在修改操作时同步所有查找流
  • modify_in_progress_:一个 cudaEventDisableTiming 事件,标记当前正在进行的修改操作。新发起的查找必须等待这个事件完成,保证读写一致性
  • private_modify_stream_:一个专用于修改操作(update / accumulate)的 CUDA 流,与查找流的执行异步进行

核心操作实现

lookup() — 查找操作

lookup() 是 GPUEmbeddingLayer 最核心的操作,完全在 GPU 上执行。完整调用链如下:

用户调用 layer->lookup(ctx, num_keys, keys, output, ...)
  │
  ├── ① ScopedDevice: 切换当前 CUDA context 到目标 GPU
  │
  ├── ② BufferWrapper: 确保 keys 与 output 在 Device 内存中
  │
  ├── ③ kernel_launch_mutex_.lock()
  │      └── cudaStreamWaitEvent(lookup_stream, modify_in_progress_)
  │          确保所有之前的修改操作已完成
  │
  ├── ④ 按 pooling 参数分流:
  │     ├── 无 pool_params → cuembed_find<KeyType>()
  │     └── 有 pool_params → cuembed_find_and_combine<KeyType>()
  │
  ├── ⑤ 如果 output 原本在 Host 上 → cudaMemcpyAsync 回传
  │
  └── ⑥ hitrates[0] = 1.0f

① ScopedDevice — RAII 设备切换

class ScopedDevice {
    ScopedDevice(int device_id) {
        cudaGetDevice(&curr_device_);
        if (device_id >= 0 && curr_device_ != device_id) {
            cudaSetDevice(device_id);
        }
    }
    ~ScopedDevice() {
        if (swap_device_) cudaSetDevice(curr_device_);
    }
};

这是一个 RAII 辅助类,确保 CUDA kernel 在正确的 GPU 设备上启动。构造函数中记录当前设备,如需切换则调用 cudaSetDevice,析构时恢复。

② BufferWrapper — 透明内存适配器

BufferWrapper(include/buffer_wrapper.hpp)是 NVE 的透明内存适配器,解决了一个关键问题:用户传入的 keys/output 指针可能在 Host 内存、Device 内存或 Unified Memory 上,而 CUDA kernel 需要 Device 指针。其核心状态机:

构造时: 检测指针类型 (BufferType)
       │
 ┌─────┴──────┐
 │ buffers_ map│  key = cudaMemoryType
 │             │  value = 内存指针
 └─────────────┘
       │
 access_buffer(target_mem_type, copy_content, stream)
       │
 ┌─────┴─────────────────────────────────────┐
 │  1. buffers_ 中是否已有 target 类型的副本?  │
 │     ├── 有 → 直接返回                      │
 │     └── 无 → ctx_->get_buffer() 分配       │
 │                                             │
 │  2. copy_content = true?                    │
 │     ├── 是 → cudaMemcpyAsync 从最后访问      │
 │     │        的缓冲区拷贝到目标缓冲区          │
 │     │    (如果是 Host 目标还需 stream sync)   │
 │     └── 否 → 跳过拷贝                       │
 │                                             │
 │  3. last_access_ = target_mem_type           │
 └─────────────────────────────────────────────┘

关键设计:

  • last_access_ 记录最近一次访问的内存类型,作为默认的 copy source
  • const 缓冲区的优化:对于 const T* 包装器,copy_content 只在首次分配时有效,后续 access 不再重复拷贝
  • Host 拷贝后同步:当目标内存类型是 cudaMemoryTypeHost 时,拷贝后调用 cudaStreamSynchronize(stream)
  • 内存所有权:分配的缓冲区的生命周期由 ExecutionContext 管理,BufferWrapper 不负责释放

③ 内核启动锁与 modify 事件同步

std::lock_guard lock(kernel_launch_mutex_);  // 保护 CUDA API 调用的原子性
NVE_CHECK_(cudaStreamWaitEvent(lookup_stream, modify_in_progress_));

kernel_launch_mutex_ 是一个 std::mutex,确保在多个线程同时调用 lookup/update/accumulate 时,CUDA 事件和 kernel launch 操作不会并发执行。cudaStreamWaitEvent 则在 GPU 硬件层面建立依赖——lookup stream 上的后续操作会等待 modify_in_progress_ 事件完成才执行,不阻塞 CPU 线程

④ cuEmbed 的 EmbeddingForward kernel 分发

call_cuembed_forward() 进入 cuembed::EmbeddingForward(),这是一个多层模板分发的函数:

第一层:数据类型分发

switch (value_dtype) {
    case Float32:  EmbeddingForward<float, float, KeyType, int64_t, false>(...);
    case Float16:  EmbeddingForward<__half, __half, KeyType, int64_t, false>(...);
}

第二层:向量宽度分发DivideRowIntoVectors<ElemT>() 根据嵌入宽度决定每次加载的向量宽度:

if (bytes_per_row % 16 == 0) bytes_per_load = 16;  // float4 / half4
else if (bytes_per_row % 8 == 0) bytes_per_load = 8;
else if (bytes_per_row % 4 == 0) bytes_per_load = 4;

第三层:合并模式 + 向量宽度组合 — 3 种模式 × 4 种宽度 = 12 个分支

第四层:热点数 + 权重/偏移组合 — 通过宏展开为 8 个分支

总共组合数:2(数据类型) × 12(模式×宽度) × 8(热点分支) = 192 种编译时特化的 kernel launch

启动参数:

auto [element_per_load, threads_per_sample, samples_per_cta] =
    GetKernelLaunchParams<ElemT, IndexT>(embed_width, num_hots, is_weighted);

dim3 launch_block(embed_width / element_per_load, samples_per_cta, 1);
dim3 launch_grid((batch_size + samples_per_cta - 1) / samples_per_cta, 1, 1);
size_t smem_size = samples_per_cta * num_hots * sizeof(IndexT);
  • block.x = 每行需要的线程数(embed_width / element_per_load
  • block.y = 每个 CTA 处理的样本数(samples_per_cta
  • grid.x = 总 batch 数

⑤ pooling 中的 CSR 路径

当使用 CSR 格式时,num_hots=0 是 cuEmbed 约定:当 offsets != nullptr && num_hots == 0 时使用 CSR 语义,每个 bag 的边界由 offsets 数组指示。

⑥ hitrates

由于 GPU 表是线性表,所有合法键都能直接访问(无缓存穿透概念),所以命中率固定为 1.0f

if (hitrates) hitrates[0] = 1.0f;

update() — 更新操作

update() 用于覆盖现有嵌入向量的值。完整同步序列:

update(ctx, num_keys, keys, ...)
  │
  ├── BufferWrapper 确保 keys/values 在 Device
  │
  ├── kernel_launch_mutex_.lock()
  │
  ├── cudaStreamWaitEvent(private_modify_stream_, modify_in_progress_)
  │       等待之前所有的 modify 完成
  │
  ├── StreamCoordinator sc(modify_stream, private_modify_stream_)
  │       构造时: modify_stream → private_modify_stream_ (事件依赖)
  │
  ├── syncEvent = contexts_->create_sync_event()
  │   syncEvent->event_record()
  │   syncEvent->event_wait_stream(private_modify_stream_)
  │       等待所有进行中的查找完成
  │
  ├── UpdateTable<KeyType>(..., private_modify_stream_)
  │       在私有修改流上执行更新 kernel
  │
  ├── cudaEventRecord(modify_in_progress_, private_modify_stream_)
  │       标记修改完成
  │
  └── cudaStreamWaitEvent(modify_stream, modify_in_progress_)
         ~StreamCoordinator: private_modify_stream_ → modify_stream_

这里存在三重同步

序号 同步原语 作用
cudaStreamWaitEvent(priv_mod, modify_in_progress) 等待上一个修改完成(写后写串行化)
syncEvent->event_wait_stream(priv_mod) + contexts_ 的所有查找流 等待所有查找流排空(读后写一致性)
cudaEventRecord(modify_in_progress, priv_mod) + 后续 wait 通知后续查找/修改:当前修改已完成

UpdateTableKernel 是一个 CUDA kernel,位于 cuda_ops/update_accumulate.cuh。它使用 subwarp 粒度的并行策略:每个 subwarp(8 或 16 个线程)处理一个索引,线程间按元素并行,将源值拷贝到嵌入表的对应行。

accumulate() — 梯度累积操作

accumulate() 用于反向传播场景,将梯度值累加到嵌入向量上(而不是覆盖)。与 update() 相似的流程,但调用的是 UpdateAccumulateTableKernel,核心差异在于使用 atomicAdd 而非直接赋值:

for (int el = threadIdx.x; el < embed_width; el += SubwarpWidth) {
    atomicAdd(embed_dst + el, embed_src[el]);  // 原子累加
}

支持两种梯度精度:

梯度类型 DataType 原子操作粒度
Float32 float atomicAdd(float*)
Float16 __half atomicAdd(half2*)

insert() / erase() / clear()

这三个方法在 GPUEmbeddingLayer 中没有实际效果,只打印警告日志:

  • insert():输出 “insert method has no effect for GPU layer, use update to change table content”
  • erase():输出 “erase method has no effect for GPU layer”
  • clear():输出 “clear method has no effect for GPU layer”

原因在于该层只是对用户已经分配好的 GPU 线性内存的一个视图(View),不拥有嵌入表的所有权,也不管理其生命周期。用户通过 config.embedding_table 传入指针,NVE 层不对表内容做增删操作。


流同步与并发模型

GPUEmbeddingLayer 的并发模型围绕三个关键组件构建。

kernel_launch_mutex_

一个 std::mutex,保护所有 CUDA API 调用(cudaStreamWaitEventcudaEventRecord、kernel launch 等)。在 update()accumulate() 中与 StreamCoordinator 配合使用,确保事件记录和等待的原子性。

modify_in_progress_ 事件

一个 cudaEventDisableTiming 事件,标记最近一次修改操作的完成时间点:

  • update() / accumulate() 提交时:在修改 stream 上记录事件
  • lookup() 发起时:在 lookup stream 上等待该事件 → 保证读后写(read-after-write)一致性
  • 下一个 update() / accumulate() 发起时:也在修改 stream 上等待该事件 → 保证写后写(write-after-write)顺序

ContextRegistry 同步

update() / accumulate() 中,修改 kernel 启动之前会通过 contexts_->create_sync_event() 创建一个同步事件,该事件会在所有已知的 lookup stream 上等待,确保没有正在执行的查找操作读取即将被修改的数据。

StreamCoordinator

StreamCoordinator 是一个 RAII 辅助类,在构造和析构时在两个 CUDA 流之间建立事件依赖:

StreamCoordinator sc(modify_stream, private_modify_stream_);
// 构造时:modify_stream → private_modify_stream_
// 析构时:private_modify_stream_ → modify_stream_

create_stream_dependency 使用临时 cudaEventcudaEventDisableTiming 轻量事件,创建开销 ~0.21μs):

static void create_stream_dependency(const cudaStream_t& src,
                                      const cudaStream_t& dst) {
    cudaEvent_t e;
    cudaEventCreateWithFlags(&e, cudaEventDisableTiming);
    cudaEventRecord(e, src);
    cudaStreamWaitEvent(dst, e);
    cudaEventDestroy(e);
}

线程安全模型总结

       Thread 1               Thread 2
          │                       │
    lookup(ctx1)             lookup(ctx2)
          │                       │
    lock(mutex)              lock(mutex)
          │                       │
waitEvent(modify_in_prog)   waitEvent(modify_in_prog)
          │                       │
    launch kernel            launch kernel
          │                       │
  unlock(mutex)             unlock(mutex)
          │                       │
    lookup 完成              lookup 完成 (并发执行)
          │                       │
          ├───────────────────────┤
                  │
             update(ctx1)
                  │
             lock(mutex)
                  │
         waitEvent(modify_in_prog)
                  │  ┌─ contexts_ sync event
                  │  │  (等待所有 lookup stream 排空)
                  │  └─ StreamCoordinator
                  │
        launch UpdateTable
                  │
        recordEvent(modify_in_prog)
                  │
             unlock(mutex)
  • 多个 lookup 可以并发:在不同的 GPUEmbeddingTableExecutionContext
  • update/accumulate 与 lookup 间:通过 modify_in_progress_ event 实现 GPU 级异步同步
  • update/accumulate 间的串行化:通过 waitEvent + recordEvent 保证写后写顺序
  • kernel_launch_mutex_:保护 CUDA API 调用的原子性,防止多线程同时创建/销毁 event

Update / Accumulate Kernel 的 Subwarp 并行

UpdateTableKernel

template<uint32_t SubwarpWidth, typename KeyType, typename DataType>
__global__ void UpdateTableKernel(...) {
    const int id = blockIdx.x * blockDim.y + threadIdx.y;  // 全局索引 ID
    if (id >= num_indices) return;

    KeyType key = indices[id];
    const DataType* embed_src = src + id * embed_src_stride_in_bytes;
    DataType* embed_dst = embedding_table + key * embed_dst_stride_in_bytes;

    // subwarp 内按元素并行
    for (int el = threadIdx.x; el < num_elements; el += SubwarpWidth) {
        embed_dst[el] = embed_src[el];
    }
}

启动参数:grid_size = (num_indices + indices_per_warp - 1) / indices_per_warpblock_size = (SubwarpWidth, indices_per_warp),其中 indices_per_warp = 32 / SubwarpWidth

以 64 维 fp32 嵌入为例:SubwarpWidth = 32indices_per_warp = 1block_size = (32, 1)。每个 block 使用 32 个线程,每个线程处理 2 个元素,一个 block 处理 1 个索引。

SubwarpWidth 的编译时选择

uint32_t subgroupWidth = std::min(nextPow2(embed_width / sizeof(DataType)), 32u);
switch (subgroupWidth) {
    case 32: CallUpdateKernelVecTypeSubwarp<32, ...>(...); break;
    case 16: ...;
    case 8:  ...;
    case 4:  ...;
    case 2:  ...;
    case 1:  ...;
}

nextPow2 返回大于等于输入的最小 2 的幂,截断到 32(warp 大小),确保 subwarp 宽度与嵌入维度对齐。

向量化加载的 dispatch

if (embed_width_in_bytes % 16 == 0)
    CallUpdatKernelVecType<KeyType, Vec4>(...);   // float4 / half4
else if (embed_width_in_bytes % 8 == 0)
    CallUpdatKernelVecType<KeyType, Vec2>(...);   // float2 / half2
else if (embed_width_in_bytes % 4 == 0)
    CallUpdatKernelVecType<KeyType, Vec1>(...);   // float

通过 VecWidthHelper 模板特化,将标量类型映射到向量类型。向量化加载能让编译器生成 LDG.E.128 / STG.E.128 指令,一次加载 16 字节,最大化内存带宽利用率。


GPUEmbeddingTableExecutionContext

GPUEmbeddingTableExecutionContextExecutionContext 的定制子类,每个执行上下文包含一对 CUDA 流(lookup_streammodify_stream),多个上下文可以并发执行查找。

class GPUEmbeddingTableExecutionContext : public ExecutionContext {
    GPUEmbeddingTableExecutionContext(...)
      : ... {
        context_registry_->add_context(this);  // 注册到 registry
    }

    ~GPUEmbeddingTableExecutionContext() {
        cudaStreamSynchronize(lookup_stream_);
        cudaStreamSynchronize(modify_stream_);
        context_registry_->remove_context(this);  // 注销
    }
};

关键点:

  • 构造时自动注册:将自身加入 ContextRegistry,使修改操作能同步到本上下文
  • 析构时自动同步:确保所有流上的待处理工作完成后再销毁
  • 缓冲区生命周期:从 ExecutionContext 继承的 buffer_storage_ 管理所有临时缓冲区

cuEmbed EmbeddingForward Kernel 的内部结构

cuEmbed 的 EmbeddingLookUpKernel 的核心逻辑是 共享内存协作式查找

对于每个 CTA 处理的 samples_per_cta 个样本:
  │
  ├── 1. 从 global memory 加载索引到共享内存
  │
  ├── 2. 屏障同步 (__syncthreads)
  │
  ├── 3. 对于嵌入向量中的每个元素 (由 threadIdx.x 负责):
  │      a. 从 params + indices[sample * num_hots + hot] * embed_width 读取
  │      b. 如果有权重: 读取权重并 MulAccumulate
  │      c. 否则: 直接 Accumulate
  │      d. 遍历 num_hots 个热点
  │
  ├── 4. 如果 mode == kMean: 除以 num_hots
  │
  └── 5. 写入输出 ret + sample * output_width

Addresser 模板负责地址计算,Combiner 模板负责累加/拼接/均值计算,IndexLoader 模板负责从共享内存加载索引。三者的编译时特化组合构成了全部 192 种 kernel 变体。对于 kConcat 模式,kernel 跳过累加阶段,直接将每个热点的嵌入向量连续写入输出缓冲区。


与其余嵌入层的对比

特性 GPUEmbeddingLayer LinearUVMEmbeddingLayer HierarchicalEmbeddingLayer
存储位置 全 GPU 显存 GPU 缓存 + UVM 线性内存 GPU 缓存 + CPU 内存 + 远程
缓存机制 GPU 集合关联缓存 三级缓存层级
查找路径 GPU kernel 直读 GPU cache → UVM fallback GPU → CPU → Remote
host fallback 有(UVM 缺页) 有(CPU 表 + 远程表)
insert/erase 无效果(warn)
适用场景 小表,全显存 中表,超显存但可 UVM 大表,远超显存
查找延迟 最低 中等(cache hit 时低) 较高(miss 时回退多)
lookup kernel cuEmbed EmbeddingForward EmbedCacheSA + cuEmbed EmbedCacheSA + CPU gather
映射关系 1:1 直接寻址 cache tag + UVM page cache tag + hash table + remote
命中率报告 固定 1.0 需 stream sync 统计 需 stream sync 统计
同步复杂度 最低 中等 最高

总结

GPUEmbeddingLayer 是 NVE 中最简单、最直接的嵌入层实现。它假定嵌入表已完整驻留在 GPU 显存中,通过 cuEmbed 的优化 kernel 实现零回退的极速查找。其代码量虽小(约 250 行 CUDA C++),但展示了完整的 CUDA 流同步模式(kernel_launch_mutex_ + modify_in_progress_ event + ContextRegistry + StreamCoordinator),为更复杂的 LinearUVMEmbeddingLayerHierarchicalEmbeddingLayer 奠定了基础。

从底层实现来看,GPUEmbeddingLayer 本质上是一个直接索引 + 向量加载的 CUDA kernel wrapper。其查找路径涵盖了 ScopedDevice 设备管理、BufferWrapper 透明内存适配、cuEmbed 的 192 种 kernel 特化分发、以及多级 CUDA 流同步等关键技术点。对于嵌入表可以完整装入显存的场景(如 Criteo 数据集的小型模型或精排阶段),GPUEmbeddingLayer 是性能最优的选择。