gpu-embedding-layer
GPUEmbeddingLayer 实现分析
概述
GPUEmbeddingLayer 是 NV Embedding Cache SDK 中最直接、最精简的嵌入层实现。它将所有嵌入数据存储在 GPU 线性内存(Linear Memory)中,查找操作完全在 GPU 上完成,不涉及主机(Host)回退。其核心思路是:如果完整嵌入表可以装入显存,那么直接在 GPU 上进行查找是最快的选择,无需任何缓存或层级回退。
该层位于 include/gpu_embedding_layer.hpp(声明)和 src/gpu_embedding_layer.cu(实现),模板参数为 KeyType,目前只支持 int32_t 和 int64_t 两种键类型。
设计原理
定位:无缓存的纯 GPU 嵌入层
与 HierarchicalEmbeddingLayer(三级缓存)和 LinearUVMEmbeddingLayer(GPU 缓存 + UVM 线性内存)不同,GPUEmbeddingLayer 不管理任何缓存。它假设用户已经将完整嵌入表放在了 GPU 显存中,通过 config.embedding_table 指针传入。查找时直接通过 cuEmbed 库的 EmbeddingForward kernel 从线性表中读取数据,写入输出缓冲区。
这对应了 README 中描述的第一种配置场景:
所有嵌入都分配在线性 GPU 内存中:使用
GPUEmbeddingLayer(C++)/ 带NoCache缓存类型的NVEmbedding(Python)。
适用场景
- 嵌入表规模较小,可以完整装入单张 GPU 显存
- 追求最低的查找延迟(无需任何 host 回退或缓存未命中处理)
- 推理场景的 baseline 配置,用于与其他缓存层做性能对比
配置结构体:GPUEmbeddingLayerConfig
struct GPUEmbeddingLayerConfig {
std::string layer_name; // 层名称(用于日志/调试)
int device_id{0}; // 使用的 GPU 设备编号
void* embedding_table; // GPU 线性内存中的嵌入表指针
int64_t num_embeddings; // 嵌入表行数
int64_t embedding_width_in_bytes; // 每行嵌入向量的字节宽度
DataType_t value_dtype; // 存储数据类型(仅用于 accumulate)
};
关键说明:
embedding_table必须是 GPU 设备内存指针,由用户预先分配和填充embedding_width_in_bytes目前只支持 fp16 和 fp32 类型,且必须能被 2 整除- 配置支持 JSON 序列化/反序列化(
from_json/to_json)
内部数据结构
template <typename KeyType>
class GPUEmbeddingLayer : public EmbeddingLayerBase {
private:
GPUEmbeddingLayerConfig config_; // 层配置
allocator_ptr_t allocator_; // 内存分配器
std::mutex kernel_launch_mutex_; // 内核启动互斥锁
std::shared_ptr<ContextRegistry> contexts_; // 执行上下文注册表
cudaEvent_t modify_in_progress_; // 修改操作进行中的 CUDA 事件
cudaStream_t private_modify_stream_; // 私有修改流
};
关键的同步基础设施:
kernel_launch_mutex_:保护所有 kernel 启动和 CUDA 事件排队操作,防止多个线程同时向 GPU 流提交命令contexts_:ContextRegistry跟踪所有存活的执行上下文,用于在修改操作时同步所有查找流modify_in_progress_:一个cudaEventDisableTiming事件,标记当前正在进行的修改操作。新发起的查找必须等待这个事件完成,保证读写一致性private_modify_stream_:一个专用于修改操作(update / accumulate)的 CUDA 流,与查找流的执行异步进行
核心操作实现
lookup() — 查找操作
lookup() 是 GPUEmbeddingLayer 最核心的操作,完全在 GPU 上执行。完整调用链如下:
用户调用 layer->lookup(ctx, num_keys, keys, output, ...)
│
├── ① ScopedDevice: 切换当前 CUDA context 到目标 GPU
│
├── ② BufferWrapper: 确保 keys 与 output 在 Device 内存中
│
├── ③ kernel_launch_mutex_.lock()
│ └── cudaStreamWaitEvent(lookup_stream, modify_in_progress_)
│ 确保所有之前的修改操作已完成
│
├── ④ 按 pooling 参数分流:
│ ├── 无 pool_params → cuembed_find<KeyType>()
│ └── 有 pool_params → cuembed_find_and_combine<KeyType>()
│
├── ⑤ 如果 output 原本在 Host 上 → cudaMemcpyAsync 回传
│
└── ⑥ hitrates[0] = 1.0f
① ScopedDevice — RAII 设备切换
class ScopedDevice {
ScopedDevice(int device_id) {
cudaGetDevice(&curr_device_);
if (device_id >= 0 && curr_device_ != device_id) {
cudaSetDevice(device_id);
}
}
~ScopedDevice() {
if (swap_device_) cudaSetDevice(curr_device_);
}
};
这是一个 RAII 辅助类,确保 CUDA kernel 在正确的 GPU 设备上启动。构造函数中记录当前设备,如需切换则调用 cudaSetDevice,析构时恢复。
② BufferWrapper — 透明内存适配器
BufferWrapper(include/buffer_wrapper.hpp)是 NVE 的透明内存适配器,解决了一个关键问题:用户传入的 keys/output 指针可能在 Host 内存、Device 内存或 Unified Memory 上,而 CUDA kernel 需要 Device 指针。其核心状态机:
构造时: 检测指针类型 (BufferType)
│
┌─────┴──────┐
│ buffers_ map│ key = cudaMemoryType
│ │ value = 内存指针
└─────────────┘
│
access_buffer(target_mem_type, copy_content, stream)
│
┌─────┴─────────────────────────────────────┐
│ 1. buffers_ 中是否已有 target 类型的副本? │
│ ├── 有 → 直接返回 │
│ └── 无 → ctx_->get_buffer() 分配 │
│ │
│ 2. copy_content = true? │
│ ├── 是 → cudaMemcpyAsync 从最后访问 │
│ │ 的缓冲区拷贝到目标缓冲区 │
│ │ (如果是 Host 目标还需 stream sync) │
│ └── 否 → 跳过拷贝 │
│ │
│ 3. last_access_ = target_mem_type │
└─────────────────────────────────────────────┘
关键设计:
last_access_记录最近一次访问的内存类型,作为默认的 copy source- const 缓冲区的优化:对于
const T*包装器,copy_content只在首次分配时有效,后续 access 不再重复拷贝 - Host 拷贝后同步:当目标内存类型是
cudaMemoryTypeHost时,拷贝后调用cudaStreamSynchronize(stream) - 内存所有权:分配的缓冲区的生命周期由
ExecutionContext管理,BufferWrapper 不负责释放
③ 内核启动锁与 modify 事件同步
std::lock_guard lock(kernel_launch_mutex_); // 保护 CUDA API 调用的原子性
NVE_CHECK_(cudaStreamWaitEvent(lookup_stream, modify_in_progress_));
kernel_launch_mutex_ 是一个 std::mutex,确保在多个线程同时调用 lookup/update/accumulate 时,CUDA 事件和 kernel launch 操作不会并发执行。cudaStreamWaitEvent 则在 GPU 硬件层面建立依赖——lookup stream 上的后续操作会等待 modify_in_progress_ 事件完成才执行,不阻塞 CPU 线程。
④ cuEmbed 的 EmbeddingForward kernel 分发
从 call_cuembed_forward() 进入 cuembed::EmbeddingForward(),这是一个多层模板分发的函数:
第一层:数据类型分发
switch (value_dtype) {
case Float32: EmbeddingForward<float, float, KeyType, int64_t, false>(...);
case Float16: EmbeddingForward<__half, __half, KeyType, int64_t, false>(...);
}
第二层:向量宽度分发 — DivideRowIntoVectors<ElemT>() 根据嵌入宽度决定每次加载的向量宽度:
if (bytes_per_row % 16 == 0) bytes_per_load = 16; // float4 / half4
else if (bytes_per_row % 8 == 0) bytes_per_load = 8;
else if (bytes_per_row % 4 == 0) bytes_per_load = 4;
第三层:合并模式 + 向量宽度组合 — 3 种模式 × 4 种宽度 = 12 个分支
第四层:热点数 + 权重/偏移组合 — 通过宏展开为 8 个分支
总共组合数:2(数据类型) × 12(模式×宽度) × 8(热点分支) = 192 种编译时特化的 kernel launch。
启动参数:
auto [element_per_load, threads_per_sample, samples_per_cta] =
GetKernelLaunchParams<ElemT, IndexT>(embed_width, num_hots, is_weighted);
dim3 launch_block(embed_width / element_per_load, samples_per_cta, 1);
dim3 launch_grid((batch_size + samples_per_cta - 1) / samples_per_cta, 1, 1);
size_t smem_size = samples_per_cta * num_hots * sizeof(IndexT);
- block.x = 每行需要的线程数(
embed_width / element_per_load) - block.y = 每个 CTA 处理的样本数(
samples_per_cta) - grid.x = 总 batch 数
⑤ pooling 中的 CSR 路径
当使用 CSR 格式时,num_hots=0 是 cuEmbed 约定:当 offsets != nullptr && num_hots == 0 时使用 CSR 语义,每个 bag 的边界由 offsets 数组指示。
⑥ hitrates
由于 GPU 表是线性表,所有合法键都能直接访问(无缓存穿透概念),所以命中率固定为 1.0f。
if (hitrates) hitrates[0] = 1.0f;
update() — 更新操作
update() 用于覆盖现有嵌入向量的值。完整同步序列:
update(ctx, num_keys, keys, ...)
│
├── BufferWrapper 确保 keys/values 在 Device
│
├── kernel_launch_mutex_.lock()
│
├── cudaStreamWaitEvent(private_modify_stream_, modify_in_progress_)
│ 等待之前所有的 modify 完成
│
├── StreamCoordinator sc(modify_stream, private_modify_stream_)
│ 构造时: modify_stream → private_modify_stream_ (事件依赖)
│
├── syncEvent = contexts_->create_sync_event()
│ syncEvent->event_record()
│ syncEvent->event_wait_stream(private_modify_stream_)
│ 等待所有进行中的查找完成
│
├── UpdateTable<KeyType>(..., private_modify_stream_)
│ 在私有修改流上执行更新 kernel
│
├── cudaEventRecord(modify_in_progress_, private_modify_stream_)
│ 标记修改完成
│
└── cudaStreamWaitEvent(modify_stream, modify_in_progress_)
~StreamCoordinator: private_modify_stream_ → modify_stream_
这里存在三重同步:
| 序号 | 同步原语 | 作用 |
|---|---|---|
| ① | cudaStreamWaitEvent(priv_mod, modify_in_progress) |
等待上一个修改完成(写后写串行化) |
| ② | syncEvent->event_wait_stream(priv_mod) + contexts_ 的所有查找流 |
等待所有查找流排空(读后写一致性) |
| ③ | cudaEventRecord(modify_in_progress, priv_mod) + 后续 wait |
通知后续查找/修改:当前修改已完成 |
UpdateTableKernel 是一个 CUDA kernel,位于 cuda_ops/update_accumulate.cuh。它使用 subwarp 粒度的并行策略:每个 subwarp(8 或 16 个线程)处理一个索引,线程间按元素并行,将源值拷贝到嵌入表的对应行。
accumulate() — 梯度累积操作
accumulate() 用于反向传播场景,将梯度值累加到嵌入向量上(而不是覆盖)。与 update() 相似的流程,但调用的是 UpdateAccumulateTableKernel,核心差异在于使用 atomicAdd 而非直接赋值:
for (int el = threadIdx.x; el < embed_width; el += SubwarpWidth) {
atomicAdd(embed_dst + el, embed_src[el]); // 原子累加
}
支持两种梯度精度:
| 梯度类型 | DataType | 原子操作粒度 |
|---|---|---|
| Float32 | float |
atomicAdd(float*) |
| Float16 | __half |
atomicAdd(half2*) |
insert() / erase() / clear()
这三个方法在 GPUEmbeddingLayer 中没有实际效果,只打印警告日志:
insert():输出 “insert method has no effect for GPU layer, use update to change table content”erase():输出 “erase method has no effect for GPU layer”clear():输出 “clear method has no effect for GPU layer”
原因在于该层只是对用户已经分配好的 GPU 线性内存的一个视图(View),不拥有嵌入表的所有权,也不管理其生命周期。用户通过 config.embedding_table 传入指针,NVE 层不对表内容做增删操作。
流同步与并发模型
GPUEmbeddingLayer 的并发模型围绕三个关键组件构建。
kernel_launch_mutex_
一个 std::mutex,保护所有 CUDA API 调用(cudaStreamWaitEvent、cudaEventRecord、kernel launch 等)。在 update() 和 accumulate() 中与 StreamCoordinator 配合使用,确保事件记录和等待的原子性。
modify_in_progress_ 事件
一个 cudaEventDisableTiming 事件,标记最近一次修改操作的完成时间点:
update()/accumulate()提交时:在修改 stream 上记录事件lookup()发起时:在 lookup stream 上等待该事件 → 保证读后写(read-after-write)一致性- 下一个
update()/accumulate()发起时:也在修改 stream 上等待该事件 → 保证写后写(write-after-write)顺序
ContextRegistry 同步
在 update() / accumulate() 中,修改 kernel 启动之前会通过 contexts_->create_sync_event() 创建一个同步事件,该事件会在所有已知的 lookup stream 上等待,确保没有正在执行的查找操作读取即将被修改的数据。
StreamCoordinator
StreamCoordinator 是一个 RAII 辅助类,在构造和析构时在两个 CUDA 流之间建立事件依赖:
StreamCoordinator sc(modify_stream, private_modify_stream_);
// 构造时:modify_stream → private_modify_stream_
// 析构时:private_modify_stream_ → modify_stream_
create_stream_dependency 使用临时 cudaEvent(cudaEventDisableTiming 轻量事件,创建开销 ~0.21μs):
static void create_stream_dependency(const cudaStream_t& src,
const cudaStream_t& dst) {
cudaEvent_t e;
cudaEventCreateWithFlags(&e, cudaEventDisableTiming);
cudaEventRecord(e, src);
cudaStreamWaitEvent(dst, e);
cudaEventDestroy(e);
}
线程安全模型总结
Thread 1 Thread 2
│ │
lookup(ctx1) lookup(ctx2)
│ │
lock(mutex) lock(mutex)
│ │
waitEvent(modify_in_prog) waitEvent(modify_in_prog)
│ │
launch kernel launch kernel
│ │
unlock(mutex) unlock(mutex)
│ │
lookup 完成 lookup 完成 (并发执行)
│ │
├───────────────────────┤
│
update(ctx1)
│
lock(mutex)
│
waitEvent(modify_in_prog)
│ ┌─ contexts_ sync event
│ │ (等待所有 lookup stream 排空)
│ └─ StreamCoordinator
│
launch UpdateTable
│
recordEvent(modify_in_prog)
│
unlock(mutex)
- 多个 lookup 可以并发:在不同的
GPUEmbeddingTableExecutionContext上 - update/accumulate 与 lookup 间:通过
modify_in_progress_event 实现 GPU 级异步同步 - update/accumulate 间的串行化:通过 waitEvent + recordEvent 保证写后写顺序
kernel_launch_mutex_:保护 CUDA API 调用的原子性,防止多线程同时创建/销毁 event
Update / Accumulate Kernel 的 Subwarp 并行
UpdateTableKernel
template<uint32_t SubwarpWidth, typename KeyType, typename DataType>
__global__ void UpdateTableKernel(...) {
const int id = blockIdx.x * blockDim.y + threadIdx.y; // 全局索引 ID
if (id >= num_indices) return;
KeyType key = indices[id];
const DataType* embed_src = src + id * embed_src_stride_in_bytes;
DataType* embed_dst = embedding_table + key * embed_dst_stride_in_bytes;
// subwarp 内按元素并行
for (int el = threadIdx.x; el < num_elements; el += SubwarpWidth) {
embed_dst[el] = embed_src[el];
}
}
启动参数:grid_size = (num_indices + indices_per_warp - 1) / indices_per_warp,block_size = (SubwarpWidth, indices_per_warp),其中 indices_per_warp = 32 / SubwarpWidth。
以 64 维 fp32 嵌入为例:SubwarpWidth = 32,indices_per_warp = 1,block_size = (32, 1)。每个 block 使用 32 个线程,每个线程处理 2 个元素,一个 block 处理 1 个索引。
SubwarpWidth 的编译时选择
uint32_t subgroupWidth = std::min(nextPow2(embed_width / sizeof(DataType)), 32u);
switch (subgroupWidth) {
case 32: CallUpdateKernelVecTypeSubwarp<32, ...>(...); break;
case 16: ...;
case 8: ...;
case 4: ...;
case 2: ...;
case 1: ...;
}
nextPow2 返回大于等于输入的最小 2 的幂,截断到 32(warp 大小),确保 subwarp 宽度与嵌入维度对齐。
向量化加载的 dispatch
if (embed_width_in_bytes % 16 == 0)
CallUpdatKernelVecType<KeyType, Vec4>(...); // float4 / half4
else if (embed_width_in_bytes % 8 == 0)
CallUpdatKernelVecType<KeyType, Vec2>(...); // float2 / half2
else if (embed_width_in_bytes % 4 == 0)
CallUpdatKernelVecType<KeyType, Vec1>(...); // float
通过 VecWidthHelper 模板特化,将标量类型映射到向量类型。向量化加载能让编译器生成 LDG.E.128 / STG.E.128 指令,一次加载 16 字节,最大化内存带宽利用率。
GPUEmbeddingTableExecutionContext
GPUEmbeddingTableExecutionContext 是 ExecutionContext 的定制子类,每个执行上下文包含一对 CUDA 流(lookup_stream 和 modify_stream),多个上下文可以并发执行查找。
class GPUEmbeddingTableExecutionContext : public ExecutionContext {
GPUEmbeddingTableExecutionContext(...)
: ... {
context_registry_->add_context(this); // 注册到 registry
}
~GPUEmbeddingTableExecutionContext() {
cudaStreamSynchronize(lookup_stream_);
cudaStreamSynchronize(modify_stream_);
context_registry_->remove_context(this); // 注销
}
};
关键点:
- 构造时自动注册:将自身加入
ContextRegistry,使修改操作能同步到本上下文 - 析构时自动同步:确保所有流上的待处理工作完成后再销毁
- 缓冲区生命周期:从
ExecutionContext继承的buffer_storage_管理所有临时缓冲区
cuEmbed EmbeddingForward Kernel 的内部结构
cuEmbed 的 EmbeddingLookUpKernel 的核心逻辑是 共享内存协作式查找:
对于每个 CTA 处理的 samples_per_cta 个样本:
│
├── 1. 从 global memory 加载索引到共享内存
│
├── 2. 屏障同步 (__syncthreads)
│
├── 3. 对于嵌入向量中的每个元素 (由 threadIdx.x 负责):
│ a. 从 params + indices[sample * num_hots + hot] * embed_width 读取
│ b. 如果有权重: 读取权重并 MulAccumulate
│ c. 否则: 直接 Accumulate
│ d. 遍历 num_hots 个热点
│
├── 4. 如果 mode == kMean: 除以 num_hots
│
└── 5. 写入输出 ret + sample * output_width
Addresser 模板负责地址计算,Combiner 模板负责累加/拼接/均值计算,IndexLoader 模板负责从共享内存加载索引。三者的编译时特化组合构成了全部 192 种 kernel 变体。对于 kConcat 模式,kernel 跳过累加阶段,直接将每个热点的嵌入向量连续写入输出缓冲区。
与其余嵌入层的对比
| 特性 | GPUEmbeddingLayer | LinearUVMEmbeddingLayer | HierarchicalEmbeddingLayer |
|---|---|---|---|
| 存储位置 | 全 GPU 显存 | GPU 缓存 + UVM 线性内存 | GPU 缓存 + CPU 内存 + 远程 |
| 缓存机制 | 无 | GPU 集合关联缓存 | 三级缓存层级 |
| 查找路径 | GPU kernel 直读 | GPU cache → UVM fallback | GPU → CPU → Remote |
| host fallback | 无 | 有(UVM 缺页) | 有(CPU 表 + 远程表) |
| insert/erase | 无效果(warn) | 有 | 有 |
| 适用场景 | 小表,全显存 | 中表,超显存但可 UVM | 大表,远超显存 |
| 查找延迟 | 最低 | 中等(cache hit 时低) | 较高(miss 时回退多) |
| lookup kernel | cuEmbed EmbeddingForward | EmbedCacheSA + cuEmbed | EmbedCacheSA + CPU gather |
| 映射关系 | 1:1 直接寻址 | cache tag + UVM page | cache tag + hash table + remote |
| 命中率报告 | 固定 1.0 | 需 stream sync 统计 | 需 stream sync 统计 |
| 同步复杂度 | 最低 | 中等 | 最高 |
总结
GPUEmbeddingLayer 是 NVE 中最简单、最直接的嵌入层实现。它假定嵌入表已完整驻留在 GPU 显存中,通过 cuEmbed 的优化 kernel 实现零回退的极速查找。其代码量虽小(约 250 行 CUDA C++),但展示了完整的 CUDA 流同步模式(kernel_launch_mutex_ + modify_in_progress_ event + ContextRegistry + StreamCoordinator),为更复杂的 LinearUVMEmbeddingLayer 和 HierarchicalEmbeddingLayer 奠定了基础。
从底层实现来看,GPUEmbeddingLayer 本质上是一个直接索引 + 向量加载的 CUDA kernel wrapper。其查找路径涵盖了 ScopedDevice 设备管理、BufferWrapper 透明内存适配、cuEmbed 的 192 种 kernel 特化分发、以及多级 CUDA 流同步等关键技术点。对于嵌入表可以完整装入显存的场景(如 Criteo 数据集的小型模型或精排阶段),GPUEmbeddingLayer 是性能最优的选择。