3 분 소요

논문 링크: Vectorizing the Trie: Efficient Constrained Decoding for LLM-based Generative Retrieval on Accelerators

“방금 추천받은 영상을 클릭했는데 삭제된 영상이라면?”

추천 시스템에서 모델의 정확도만큼이나 중요한 것이 바로 ‘유효성(Validity)’입니다. 최근 YouTube, Amazon 등 거대 테크 기업들은 기존의 임베딩 기반 검색 모델에서 LLM을 활용한 생성 검색(Generative Retrieval)으로 패러다임을 전환하고 있습니다.

하지만 LLM은 태생적으로 치명적인 약점을 가지고 있습니다. 바로 존재하지 않거나 조건에 맞지 않는 아이템을 추천하는 환각(Hallucination) 현상입니다. 이를 해결하기 위한 YouTube의 혁신적인 디코딩 가속화 기술, STATIC을 소개합니다.


1. 생성 검색과 ‘유효성 갭(Validity Gap)’

생성 검색 모델은 각 아이템의 의미론적 식별자(Semantic ID)를 토큰 단위로 직접 자기회귀적으로 디코딩합니다. 이 방식은 별도의 외부 근사 최근접 이웃 인프라 없이도 아이템 간의 깊은 관계를 포착할 수 있습니다. 하지만 비즈니스 로직(예: 최근 7일 이내 영상, 재고 있음, 지역 제한)을 강제하려면 제약 디코딩(Constrained Decoding)이 필수적입니다.

사후에 필터링을 거치게 되면 LLM이 이미 무효한 아이템을 생성하는 데 추론 예산을 모두 낭비할 수 있으므로 매우 비효율적입니다. 따라서 디코딩 과정 자체에서 생성 가능한 토큰의 범위를 강제로 제한해야 합니다.


2. 기존 Trie 기반 제약 디코딩의 한계

유효하지 않은 토큰을 마스킹하기 위해 가장 널리 쓰이는 표준적인 해결책은 접두사 트리(Trie) 자료구조를 활용하는 것입니다. 하지만 이 방식은 최신 AI 가속기(GPU/TPU) 환경과 최악의 궁합을 보입니다.

  • 메모리 병목: 포인터 기반의 트리 탐색은 불규칙하고 비연속적인 메모리 접근을 유발하여, 최신 가속기의 고대역폭 메모리(HBM) 활용을 떨어뜨립니다.
  • 컴파일 비호환성: TPU나 최신 GPU는 머신러닝 컴파일(예: XLA)을 위해 정적인 연산 그래프를 요구하지만, 데이터 의존적인 제어 흐름을 가진 Trie는 근본적으로 호환되지 않습니다.
  • 속도 저하: 연산을 CPU로 넘길 경우 추론 시간이 2배로 늘어나며, 하드웨어 가속 이진 탐색 기법(PPV)들 역시 제약 세트 크기에 따라 대수적($O(\log C )$)으로 확장되는 I/O 병목이 발생합니다.

3. STATIC의 혁신: 트리를 부수고 행렬을 곱하다

{00F64756-7D39-4459-A65F-0B5085CD4F5E}

STATIC(Sparse Transition Matrix-Accelerated Trie Index for Constrained Decoding)의 핵심 아이디어는 불규칙한 그래프 탐색 문제를 완전히 벡터화된 희소 행렬 연산으로 재구성하는 것입니다.

💡 핵심 기술 3가지

  • 희소 전이 행렬 (Sparse Transition Matrix, STM): Trie 구조를 오프라인에서 정적인 CSR(Compressed Sparse Row) 행렬로 평탄화합니다. 이를 통해 단일 병합 읽기(Coalesced Read)로 $O(1)$의 메모리 접근 오버헤드만 발생하며 빠른 제약 추출이 가능해집니다.
  • 분기 없는 커널 설계 (Vectorized Node Transition Kernel, VNTK): GPU 워프 발산이나 TPU 그래프 재컴파일 문제를 피하기 위해 동적 분기를 제거했습니다. 노드마다 자식의 개수가 달라도 항상 고정된 크기($B_l$)만큼 데이터를 선제적으로 슬라이싱하여 가져오고, 유효하지 않은 데이터는 마스크(Mask) 처리하여 병렬 처리 효율을 극대화합니다.
  • 하이브리드 마스킹 아키텍처: 노드가 빽빽하게 밀집된 초기 레벨(첫 $d$개의 디코딩 단계)에서는 고밀도(Dense) 텐서 마스크를 통해 즉각적인 조회를 수행하고, 깊은 하위 레벨에서만 VNTK를 활용한 희소 행렬 연산을 적용해 속도를 최적화했습니다.

4. 압도적인 성능 벤치마크

YouTube의 대규모 비디오 추천 환경(2,000만 개의 신선한 영상 제약 조건)에서 STATIC을 테스트한 결과는 놀랍습니다.

{F66EFA17-AB11-4B61-B98A-D76B11931E6B}

  • 디코딩 지연 시간: 단계당 단 0.033ms가 소요되며, 이는 전체 추론 시간의 0.25%에 불과합니다.
  • 속도 향상: CPU Trie 기반 방식 대비 948배, 기존 하드웨어 가속 이진 탐색(PPV Exact) 기법 대비 1,033배 빠릅니다.
  • 메모리 효율: 2,000만 개의 제약 조건을 처리하는 데 최대 약 1.5GB의 고대역폭 메모리 용량만 소비합니다.

5. YouTube 프로덕션 환경에서의 비즈니스 임팩트

STATIC은 연구실에 머물지 않고 수십억 명이 이용하는 실제 YouTube ‘홈 피드’ 추천에 배포되었습니다. 30억 파라미터 규모의 모델에 ‘최근 7일 이내 업로드’라는 제약 조건을 부여하여 A/B 테스트를 진행했습니다.

  • 모델은 제약 조건을 100% 완벽하게 준수하여 오래된 영상 추천을 원천 차단했습니다.
  • 7일 이내 신선한 영상 조회수가 +5.1% 증가했습니다.
  • 사용자 클릭률(CTR)이 +0.15%, 전략적 사용자 세그먼트의 만족도가 +0.15% 상승했습니다.

6. 결론 및 인사이트

STATIC은 수십 년간 사용해 온 Trie라는 고전적인 자료구조를 AI 가속기 시대의 문법인 ‘정적 희소 행렬’로 재해석하여 한계를 돌파했습니다.

단순히 AI 모델의 파라터를 키우는 것을 넘어, 하드웨어의 특성을 완벽히 이해하고 시스템 엔지니어링 단위에서 병목을 해결한 훌륭한 사례입니다. 실시간 대규모 LLM 서빙을 고민하는 개발자라면 반드시 참고해 볼 만한 기술적 이정표가 될 것입니다.


7. 부록: JAX/Flax 기반 핵심 코드 리뷰

논문에 수록된 실제 JAX 구현체를 살펴보면 STATIC의 철학이 코드 레벨에서 어떻게 적용되었는지 명확히 알 수 있습니다.

Point 1. 분기 없는 벡터화 추출 (VNTK 커널)

기존 Trie가 포인터를 따라 동적으로 탐색했다면, STATIC은 jnp.take를 활용해 정적인 크기(layer_max_branches)만큼의 후보를 한 번에 가져옵니다.

# Algorithm 2: Vectorized Node Transition Kernel (VNTK)
# 동적 슬라이싱을 jnp.take 연산 하나로 대체하여 정적 그래프를 유지
gathered = jnp.take(
    transition_matrix.data,
    starts[:, None] + offsets[None, :],
    axis=0,
    mode="fill",
    fill_value=0
)

Point 2. 마스킹을 통한 유효성 검증

자식 노드 개수(lens) 범위를 벗어나는 빈 슬롯들은 valid_mask를 씌운 뒤, jnp.where를 통해 확률값을 -inf(NEG_INF)로 덮어씌워 폐기합니다.

valid_mask = offsets[None, :] < lens[:, None]
# 유효하지 않은 경로는 -inf 처리하여 Beam Search에서 자연스럽게 탈락시킴
return jnp.where(valid_mask, candidate_lp, NEG_INF)

Point 3. 하이브리드 고밀도(Dense) 조회

디코딩 극초반(Step 1~2)에는 노드가 너무 많아 Sparse 행렬 연산이 오히려 비효율적입니다. 따라서 사전에 패킹해 둔 고밀도 마스크(11_dense_mask_packed)를 통해 즉각적으로 $O(1)$ 필터링을 수행합니다.

if step <= 1 and transition_matrix.11_dense_mask_packed is not None:
    mask = jnp.unpackbits(
        transition_matrix.11_dense_mask_packed[parents], axis=-1
    )
    return jnp.where(mask.astype(bool)[..., :vocab_size], log_probs, NEG_INF)

댓글남기기