extract-kernel-definitions by flashinfer-ai
Extract kernel schemas and definitions from SGLang model implementations with deduplication. Use when adding a new model, extracting GPU kernels (MLA, MoE, GQA, RMSNorm, GEMM), or generating Definition JSON files for flashinfer_trace.
Content & Writing
119 Stars
17 Forks
Updated Jan 18, 2026, 02:27 PM
Why Use This
This skill provides specialized capabilities for flashinfer-ai's codebase.
Use Cases
- Developing new features in the flashinfer-ai repository
- Refactoring existing code to follow flashinfer-ai standards
- Understanding and working with flashinfer-ai's codebase structure
Install Guide
2 steps- 1
Skip this step if Ananke is already installed.
- 2
Skill Snapshot
Auto scan of skill assets. Informational only.
Valid SKILL.md
Checks against SKILL.md specification
Source & Community
Skill Stats
SKILL.md 706 Lines
Total Files 1
Total Size 0 B
License NOASSERTION
---
name: extract-kernel-definitions
description: Extract kernel schemas and definitions from SGLang model implementations with deduplication. Use when adding a new model, extracting GPU kernels (MLA, MoE, GQA, RMSNorm, GEMM), or generating Definition JSON files for flashinfer_trace.
---
# Extract Kernel Definitions
Extract kernel schemas and definitions from SGLang model implementations, with deduplication, and add them to `./flashinfer_trace/` with vanilla Python reference implementations. Uses sgl-cookbook serving configurations to generate multiple kernel definitions for different TP/EP settings.
## Description
This skill analyzes SGLang model implementations to extract the complete set of GPU kernels used during inference. It identifies kernel types (MLA, MOE, GQA, RMSNorm, GEMM), extracts their parameters, generates Definition JSON schemas with Python reference implementations, and handles deduplication across multiple models.
**Key Feature**: Uses sgl-cookbook repository to find recommended serving configurations (tensor parallel, expert parallel flags) and generates multiple kernel definitions for different parallelism settings. For example, Qwen3-Next has TP=2 and TP=4 configs, resulting in separate GDN definitions with different head counts.
## Usage
```bash
# Extract kernels from DeepSeek V3 model (MLA + MoE)
/extract-kernel-definitions --model-name deepseek_v3
# Extract from Llama (GQA)
/extract-kernel-definitions --model-name llama
# Extract from multiple models with automatic deduplication
/extract-kernel-definitions --model-name deepseek_v3
/extract-kernel-definitions --model-name llama
/extract-kernel-definitions --model-name qwen2_moe
# Extract only prefill or decode kernels
/extract-kernel-definitions --model-name llama --execution-modes prefill
/extract-kernel-definitions --model-name llama --execution-modes decode
```
## Parameters
- `model_name` (required): Model name to extract kernels from (e.g., "deepseek_v3", "llama", "qwen2_moe")
- `execution_modes` (optional): Inference execution modes to analyze (default: ["prefill", "decode"])
- `include_quantized` (optional): Include quantized kernel variants (default: true)
- `deduplicate` (optional): Check for and skip existing definitions (default: true)
## Prerequisites
Run `/clone-repos` first to set up the `tmp/` directory with SGLang, FlashInfer, and sgl-cookbook (the `flashinfer_trace/` directory is already part of this repository).
## What This Skill Does
### Phase 1: Model Analysis
1. **Locate Model Implementation**:
- Search `tmp/sglang/python/sglang/srt/models/{model_name}.py`
- Identify model class (e.g., `DeepseekV3ForCausalLM`, `LlamaForCausalLM`)
- Parse model architecture from config
2. **Find Serving Configurations (NEW)**:
- Search `tmp/sgl-cookbook/data/models/generated/v0.5.6/` for model-specific YAML configs
- Identify all recommended TP (tensor parallel) and EP (expert parallel) settings across different hardware platforms
- Parse YAML files to extract `tp` and `ep` values from hardware configurations
- **Example**: Qwen3-Next (qwen3next.yaml) has:
- H100: tp=4 (requires definitions with heads split by 4)
- H200/B200: tp=2 (requires definitions with heads split by 2)
- **Example**: DeepSeek V3.2 (deepseek.yaml) has:
- All platforms: tp=8 (base config)
- high-throughput-ep: tp=8, ep=8 (affects num_local_experts)
3. **Identify Layer Components**:
- Attention mechanism (GQA, MHA, MLA, GDN)
- MLP/FFN structure (Dense, MoE)
- Normalization layers (RMSNorm, LayerNorm)
- Embedding and output projections
4. **Extract Execution Paths**:
- Prefill path (batch processing, variable sequence length)
- Decode path (single token, paged KV cache)
### Phase 2: Kernel Extraction with Parallelism Variants
For each layer component AND each serving configuration (TP/EP setting), extract:
#### Attention Kernels (Affected by TP)
- **GQA**: `gqa_paged_decode`, `gqa_paged_prefill`, `gqa_ragged_prefill`
- **TP Impact**: `num_qo_heads` = original_heads / TP, `num_kv_heads` = original_kv_heads / TP
- **MLA**: `mla_paged_decode`, `mla_paged_prefill`
- **TP Impact**: `num_qo_heads` = original_heads / TP
- **GDN**: `gdn_decode`, `gdn_prefill`
- **TP Impact**: `num_q_heads` = original_q_heads / TP, `num_v_heads` = original_v_heads / TP
- **Example**: Qwen3-Next with TP=2 → `gdn_decode_qk8_v16_d128_k_last`
- **Example**: Qwen3-Next with TP=4 → `gdn_decode_qk4_v8_d128_k_last`
- **Parameters**: num_heads, num_kv_heads, head_dim, page_size, ckv_dim, kpe_dim
#### MoE Kernels (Affected by EP)
- Expert routing and selection
- Expert execution (FP8 quantized, block-scaled)
- **EP Impact**: `num_local_experts` = num_experts / EP
- **Example**: DeepSeek V3 has 256 experts
- EP=1 → `num_local_experts=256`
- EP=2 → `num_local_experts=128`
- EP=4 → `num_local_experts=64`
- EP=8 → `num_local_experts=32`
- **Parameters**: num_experts (global), num_local_experts (per device), topk, hidden_size, intermediate_size, group_size
#### Normalization Kernels (NOT affected by TP/EP — skip tp/ep tags)
- `rmsnorm_h{hidden_size}`
- `fused_add_rmsnorm_h{hidden_size}`
- **Parameters**: hidden_size, epsilon
- **Note**: Same definition across all TP/EP configs. Do NOT add `tp:N` or `ep:N` tags. Refer to sgl-cookbook to confirm the serving config, but the kernel definition itself is parallelism-agnostic.
#### GEMM Kernels (NOT affected by TP/EP — skip tp/ep tags)
- QKV projection, O projection
- Gate/Up projection, Down projection
- **Parameters**: M (variable), N (output dim), K (input dim)
- **Note**: Shape changes handled at runtime, definition remains constant. Do NOT add `tp:N` or `ep:N` tags.
#### Sampling Kernels (NOT affected by TP/EP — skip tp/ep tags)
- `top_k_sampling_from_probs`, `top_p_sampling_from_probs`, etc.
- **Parameters**: vocab_size, batch_size
- **Note**: Sampling operates on per-token logits after all-reduce; the vocab dimension is never split. Do NOT add `tp:N` or `ep:N` tags.
### Phase 3: Deduplication
1. **Load Existing Definitions**:
- Scan `flashinfer_trace/definitions/` for existing JSONs
- Build index of definition names and signatures
2. **Compare Extracted Kernels**:
- Check if kernel with same name exists
- Verify parameter compatibility
- Skip existing definitions, report what was skipped
3. **Handle Shared Kernels**:
- Kernels like `rmsnorm_h4096` may be used by multiple models
- Only create once, add model tags to existing definitions
### Phase 4: Definition Generation with Parallelism Tags
For each new kernel AND each TP/EP configuration, generate a Definition JSON following the standards below.
**IMPORTANT**: When generating multiple definitions for different TP/EP configs:
1. Create separate JSON files with parameters reflecting the parallelism split
2. Add `tp:N` / `ep:N` tags **only for kernels whose input sizes are affected by parallelism** (attention/GDN kernels for TP; MoE kernels for EP). **Skip these tags for parallelism-agnostic kernels** (rmsnorm, gemm, sampling) — instead, refer to sgl-cookbook to confirm the serving config.
3. Update constant axes to reflect post-split values
4. Include TP/EP info in the description field
## Definition JSON Standards
### Naming Convention
Follow the pattern: `{op_type}_{variant}_{key_params}`
**Parameter Abbreviations:**
- `h` = num_heads or hidden_size (context-dependent)
- `kv` = num_kv_heads
- `d` = head_dim
- `ps` = page_size
- `ckv` = compressed_kv_dim
- `kpe` = key_positional_encoding_dim
- `e` = num_experts
- `i` = intermediate_size
- `topk` = top_k_experts
- `ng` = n_group
- `kg` = topk_group
- `v` = vocab_size
**Examples by op_type:**
- Attention: `gqa_paged_decode_h32_kv8_d128_ps1`, `mla_paged_prefill_h16_ckv512_kpe64_ps1`
- MoE: `moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048`
- Normalization: `rmsnorm_h4096`, `fused_add_rmsnorm_h7168`
- Sampling: `top_k_sampling_from_probs_v129280`
### Tag Patterns
Tags follow the pattern `{category}:{value}`:
| Category | Values | Description |
|----------|--------|-------------|
| `status` | `verified`, `unverified` | Whether reference implementation is validated |
| `stage` | `decode`, `prefill` | Inference execution mode |
| `model` | `deepseek-v3`, `deepseek-r1`, `llama-3.1-8b`, etc. | Associated model(s) |
| `tp` | `1`, `2`, `4`, `8`, … | Tensor parallel size this definition was captured at (affects head counts for attention/GDN kernels). **Omit for parallelism-agnostic kernels** (rmsnorm, gemm, sampling) — refer to sgl-cookbook to confirm TP but do not add this tag. |
| `ep` | `1`, `2`, `4`, `8`, … | Expert parallel size this definition was captured at (affects `num_local_experts` for MoE kernels). **Omit for parallelism-agnostic kernels** (rmsnorm, gemm, sampling). |
| `quantization` | `float8_e4m3fn`, `nvfp4`, `int8`, `int4` | Quantization format |
| `routing` | `pre-computed`, `on-the-fly` | For MoE routing type |
| `fi_api` | `flashinfer.norm.rmsnorm`, `flashinfer.mla.BatchMLAPagedAttentionWrapper`, etc. | FlashInfer Python API name for this kernel (omit if no FlashInfer API exists) |
**FlashInfer API Name Mapping by op_type:**
| op_type | FlashInfer API | Notes |
|---------|----------------|-------|
| `rmsnorm` | `flashinfer.norm.rmsnorm` | |
| `rmsnorm` (fused_add) | `flashinfer.norm.fused_add_rmsnorm` | |
| `gqa_paged` (decode) | `flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper` | |
| `gqa_paged` (prefill) | `flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` | |
| `gqa_ragged` | `flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper` | |
| `mla_paged` | `flashinfer.mla.BatchMLAPagedAttentionWrapper` | |
| `dsa_paged` | `flashinfer.sparse.BlockSparseAttentionWrapper` | Verify in `tmp/flashinfer/python/flashinfer/` |
| `gdn` (prefill) | `flashinfer.gdn.chunk_gated_delta_rule` | Defined in `gdn_prefill.py` |
| `gdn` (decode) | `flashinfer.gdn.gated_delta_rule_decode` | Defined in `gdn_decode.py` |
| `gdn` (mtp/multi-token-predict) | `flashinfer.gdn.gated_delta_rule_mtp` | Defined in `gdn_decode.py` |
| `moe` (fp8 block scale) | `flashinfer.fused_moe.trtllm_fp8_block_scale_moe` | Defined in `flashinfer/fused_moe/core.py` |
| `moe` (other variants) | N/A — Not Supported here yet | FlashInfer MoE coverage varies |
| `gemm` | N/A — use `torch.nn.functional.linear` | No dedicated FlashInfer API |
| `sampling` | `flashinfer.sampling.top_k_sampling_from_probs`, etc. | Match specific sampling variant |
**Example tags array:**
```json
"tags": [
"stage:decode",
"status:verified",
"model:deepseek-v3",
"model:deepseek-r1",
"quantization:float8_e4m3fn",
"fi_api:flashinfer.mla.BatchMLAPagedAttentionWrapper",
"tp:8"
]
```
### Axes Structure
```json
"axes": {
"batch_size": {
"type": "var",
"description": "Batch size (number of sequences)"
},
"num_qo_heads": {
"type": "const",
"value": 16,
"description": "Number of query heads after tensor parallel split (128/8=16)."
}
}
```
**Rules:**
- Variable axes (`type: "var"`): runtime dimensions like batch_size, seq_len, num_pages
- Constant axes (`type: "const"`): model-specific values with `value` field
- Always include `description` for complex or model-specific axes
### Constraints Field
Add constraints for input validation:
```json
"constraints": [
"len_indptr == batch_size + 1",
"num_kv_indices == kv_indptr[-1].item()"
]
```
### Inputs/Outputs Structure
```json
"inputs": {
"tensor_name": {
"shape": ["axis1", "axis2", "axis3"],
"dtype": "bfloat16",
"description": "Description of the tensor"
},
"scalar_input": {
"shape": null,
"dtype": "float32",
"description": "Scalar parameter"
}
}
```
**dtype values:** `float32`, `float16`, `bfloat16`, `float8_e4m3fn`, `float8_e5m2`, `int32`, `int64`
### Reference Implementation
1. **Generate Name**: Follow naming convention `{op_type}_{variant}_{params}`
- Example: `mla_paged_decode_h16_ckv512_kpe64_ps1`
- Example: `moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048`
2. **Define Axes**:
- Constant axes: fixed at compile time (e.g., hidden_size, num_heads)
- Variable axes: determined at runtime (e.g., batch_size, seq_len)
3. **Specify Inputs/Outputs**:
- Tensor shapes using axis names
- Data types (float16, bfloat16, float8_e4m3fn, etc.)
4. **Write Reference Implementation**:
- **ALWAYS prioritize FlashInfer tests** (`tmp/flashinfer/tests/`) for reference implementations
- Look for vanilla PyTorch reference functions (e.g., `ref_attention()`, `ref_rmsnorm()`)
- Only use SGLang vanilla implementation when FlashInfer doesn't have the kernel
- Plain PyTorch implementation with step-by-step computation
- See "Reference Implementation Sources" section below
### Phase 5: Save Definitions
Output to `flashinfer_trace/definitions/{op_type}/{definition_name}.json`
## Output Format
### Definition JSON Example (MLA Decode)
```json
{
"name": "mla_paged_decode_h16_ckv512_kpe64_ps1",
"description": "Batched Multi-head Latent Attention decode with a paged KV cache. Captured from DeepSeek-V3 with tensor parallel size 8.",
"op_type": "mla_paged",
"tags": [
"stage:decode",
"status:verified",
"model:deepseek-v3",
"model:deepseek-r1",
"fi_api:flashinfer.mla.BatchMLAPagedAttentionWrapper",
"tp:8"
],
"axes": {
"batch_size": { "type": "var" },
"num_qo_heads": {
"type": "const",
"value": 16,
"description": "Number of query heads after tensor parallel split (128/8=16)."
},
"head_dim_ckv": { "type": "const", "value": 512 },
"head_dim_kpe": { "type": "const", "value": 64 },
"page_size": { "type": "const", "value": 1 },
"num_pages": { "type": "var", "description": "Total number of allocated pages in the KV cache." },
"len_indptr": { "type": "var", "description": "Length of kv_indptr array." },
"num_kv_indices": { "type": "var", "description": "Total number of KV page indices." }
},
"constraints": [
"len_indptr == batch_size + 1",
"num_kv_indices == kv_indptr[-1].item()"
],
"inputs": {
"q_nope": {
"shape": ["batch_size", "num_qo_heads", "head_dim_ckv"],
"dtype": "bfloat16",
"description": "Query tensor without positional encoding component."
},
"q_pe": {
"shape": ["batch_size", "num_qo_heads", "head_dim_kpe"],
"dtype": "bfloat16",
"description": "Query positional encoding component."
},
"ckv_cache": {
"shape": ["num_pages", "page_size", "head_dim_ckv"],
"dtype": "bfloat16",
"description": "Compressed key-value cache."
},
"kpe_cache": {
"shape": ["num_pages", "page_size", "head_dim_kpe"],
"dtype": "bfloat16",
"description": "Key positional encoding cache."
},
"kv_indptr": {
"shape": ["len_indptr"],
"dtype": "int32",
"description": "KV page offsets for each sequence."
},
"kv_indices": {
"shape": ["num_kv_indices"],
"dtype": "int32",
"description": "Page indices for KV cache lookups."
},
"sm_scale": {
"shape": null,
"dtype": "float32",
"description": "Softmax scale. Default is (1/sqrt(128 + 64) = 1/sqrt(192))."
}
},
"outputs": {
"output": { "shape": ["batch_size", "num_qo_heads", "head_dim_ckv"], "dtype": "bfloat16" },
"lse": {
"shape": ["batch_size", "num_qo_heads"],
"dtype": "float32",
"description": "The 2-based log-sum-exp of attention logits."
}
},
"reference": "import math\nimport torch\n\[email protected]_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, sm_scale):\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n ..."
}
```
### Definition JSON Example (MoE)
```json
{
"name": "moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048",
"description": "FP8 block scale MoE operation. Routing and two grouped-GEMM included.",
"op_type": "moe",
"tags": [
"status:verified",
"model:deepseek-v3",
"model:deepseek-r1",
"quantization:float8_e4m3fn",
"ep:8"
],
"axes": {
"seq_len": { "type": "var", "description": "Sequence length (number of tokens)" },
"num_experts": {
"type": "const",
"value": 256,
"description": "Total number of experts."
},
"num_local_experts": {
"type": "const",
"value": 32,
"description": "Number of local experts with EP size 8."
},
"hidden_size": {
"type": "const",
"value": 7168,
"description": "Hidden dimension size."
},
"intermediate_size": {
"type": "const",
"value": 2048,
"description": "MoE intermediate layer size."
}
},
"inputs": {
"routing_logits": {
"shape": ["seq_len", "num_experts"],
"dtype": "float32",
"description": "Tensor of routing logits for expert selection"
},
"hidden_states": {
"shape": ["seq_len", "hidden_size"],
"dtype": "float8_e4m3fn",
"description": "Input hidden states tensor (FP8 quantized)"
},
"local_expert_offset": {
"shape": null,
"dtype": "int32",
"description": "Offset of local experts in global expert space."
},
"routed_scaling_factor": {
"shape": null,
"dtype": "float32",
"description": "Scaling factor for routing weights."
}
},
"outputs": {
"output": {
"shape": ["seq_len", "hidden_size"],
"dtype": "bfloat16",
"description": "Final MoE output tensor"
}
},
"reference": "import torch\n\[email protected]_grad()\ndef run(routing_logits, hidden_states, ...):\n # Check constants\n assert H == 7168, 'hidden_size must be 7168'\n ..."
}
```
> **Note**: The MoE FP8 block-scale variant has a FlashInfer API (`flashinfer.fused_moe.trtllm_fp8_block_scale_moe`); update the `tags` array accordingly. Other MoE variants without FlashInfer support should omit the `fi_api:` tag and use SGLang as ground truth. In general, omit `fi_api:` whenever no FlashInfer API covers the kernel.
## Implementation Steps
When executing this skill:
1. **Locate model file**:
```bash
ls tmp/sglang/python/sglang/srt/models/ | grep -i {model_name}
```
2. **Find sgl-cookbook configs for the model**:
```bash
# Find the matching YAML config file
find tmp/sgl-cookbook/data/models/generated/v0.5.6/ -name "*{model_name}*.yaml"
# Example: for qwen3-next
cat tmp/sgl-cookbook/data/models/generated/v0.5.6/qwen3next.yaml
```
3. **Parse YAML to extract TP/EP configurations**:
- Read YAML file and parse hardware configurations
- Extract all unique `tp` values across all hardware platforms and configurations
- Extract all unique `ep` values (if model uses MoE)
- **Example parsing result for Qwen3-Next**:
```
TP configs: [2, 4]
EP configs: [null] (no MoE)
```
- **Example parsing result for DeepSeek-V3.2**:
```
TP configs: [8]
EP configs: [null, 8]
```
4. **Read model implementation**:
- Parse the Python file
- Identify forward() methods
- Extract kernel calls (attention, MoE, norm, GEMM)
5. **Extract kernel parameters from model config**:
- Look for config class (e.g., `DeepseekV3Config`)
- Extract: num_hidden_layers, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, etc.
6. **Check existing definitions**:
```bash
ls flashinfer_trace/definitions/*/
```
7. **Generate definition JSONs for each TP/EP config**:
- For each unique TP value, calculate split head counts
- For each unique EP value, calculate local expert counts
- Create directory if needed: `mkdir -p flashinfer_trace/definitions/{op_type}/`
- Write JSON file with reference implementation
- **Example**: For Qwen3-Next GDN kernel with original q_heads=16, v_heads=32:
- TP=2: Create `gdn_decode_qk8_v16_d128_k_last.json` (16/2=8, 32/2=16)
- TP=4: Create `gdn_decode_qk4_v8_d128_k_last.json` (16/4=4, 32/4=8)
8. **Report results**:
- List new definitions created (grouped by TP/EP config)
- List existing definitions skipped (deduplication)
- List kernels shared across models
## SGLang Code Patterns to Look For
### Attention Kernel Calls
```python
# GQA decode (paged)
flashinfer.batch_decode_with_paged_kv_cache(...)
# MLA decode
from sglang.srt.layers.attention.mla_decode import mla_decode_attention
```
### MoE Kernel Calls
```python
# DeepSeek MoE
from sglang.srt.layers.moe.fused_moe import fused_moe
fused_moe(hidden_states, w1, w2, w3, topk_weights, topk_ids, ...)
```
### Normalization
```python
# RMSNorm
from sglang.srt.layers.layernorm import RMSNorm
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
```
## Ground Truth Hierarchy
When extracting kernel definitions, use different sources for different purposes:
### For Model Constants: HuggingFace + SGLang (Required)
**IMPORTANT**: Always refer to the HuggingFace model page first (`https://huggingface.co/{org}/{model-name}`) for authoritative model constants when creating definitions for new kernels. Download `config.json` from the model repo for values like `hidden_size`, `num_experts`, `topk`, etc.
Use SGLang (`tmp/sglang/python/sglang/srt/models/{model_name}.py`) for runtime-specific constants like `page_size` and tensor parallel configurations (e.g. `num_attention_heads`, `num_local_experts`).
### For Reference `run()` Implementation: FlashInfer Unit Tests (PRIMARY - ALWAYS CHECK FIRST)
- **When**: FlashInfer has a kernel implementation and corresponding unit test **(CHECK THIS FIRST)**
- **Location**: `tmp/flashinfer/tests/`
- **Why**: FlashInfer tests contain vanilla PyTorch implementations that serve as ground truth
- **Examples**:
```
tests/test_batch_decode.py # GQA decode reference
tests/test_batch_prefill.py # GQA prefill reference
tests/test_norm.py # RMSNorm reference
tests/test_mla.py # MLA reference
```
- **Pattern**: Look for functions like `ref_attention()`, `ref_rmsnorm()`, `ref_mla()`, etc.
- **Use for**: Writing the `reference` field in Definition JSON
- **IMPORTANT**: Always search FlashInfer tests directory FIRST before falling back to SGLang
### For Reference `run()` Implementation: SGLang Vanilla (FALLBACK ONLY)
- **When**: FlashInfer does NOT have the kernel (e.g., some MoE variants, custom ops)
- **Location**: `tmp/sglang/python/sglang/srt/layers/`
- **Examples**:
```
layers/moe/fused_moe.py # MoE vanilla forward (when FlashInfer MoE not available)
layers/attention/triton_ops/ # Custom attention implementations
```
- **Important**: Only use SGLang vanilla implementation when FlashInfer doesn't have the kernel
- **How to verify**: Check `tmp/flashinfer/tests/` first to confirm kernel is not available
### For API Signature: FlashInfer Python API
- **When**: Determining input/output tensor specifications
- **Location**: `tmp/flashinfer/python/flashinfer/`
- **Examples**:
- `flashinfer.attention.batch_decode_with_paged_kv_cache`
- `flashinfer.norm.rmsnorm`
- `flashinfer.mla.BatchMLAPagedAttentionWrapper`
- **Use for**: Input/output shape and dtype specifications in Definition JSON
## Reference Implementation Sources
The `reference` field in Definition JSON contains a `run()` function. **Always prioritize FlashInfer unit tests over SGLang implementations.**
### Primary Source: FlashInfer Unit Tests (REQUIRED - CHECK THIS FIRST)
- **Location**: `tmp/flashinfer/tests/`
- **Why**: FlashInfer tests contain ground-truth vanilla PyTorch implementations
- **How to find**: Search for reference functions in test files
```bash
# Search for reference implementations
grep -r "def ref_" tmp/flashinfer/tests/
grep -r "def reference" tmp/flashinfer/tests/
```
- **Kernel type to test file mapping**:
| Kernel Type | Test File | Reference Function |
|-------------|-----------|-------------------|
| GQA decode | `test_batch_decode.py` | `ref_attention()` or inline reference |
| GQA prefill | `test_batch_prefill.py` | `ref_attention()` or inline reference |
| MLA | `test_mla.py` | `ref_mla()` or inline reference |
| RMSNorm | `test_norm.py` | `ref_rmsnorm()` or `ref_fused_add_rmsnorm()` |
| Sampling | `test_sampling.py` | Reference sampling implementations |
- **Important**: The reference `run()` should match FlashInfer's test implementation exactly, with constant values updated to match SGLang model config
### Fallback Source: SGLang Vanilla Implementation (ONLY when FlashInfer unavailable)
- **When to use**: ONLY when FlashInfer does NOT have a unit test for this kernel
- **Location**: `tmp/sglang/python/sglang/srt/layers/`
- **Common cases requiring SGLang fallback**:
- Custom MoE implementations: `layers/moe/fused_moe.py`
- Model-specific attention variants not in FlashInfer
- Quantized kernels not yet supported by FlashInfer
- **How to check if FlashInfer has the kernel**:
```bash
# Check if FlashInfer has a test for this kernel type
ls tmp/flashinfer/tests/test_*.py | xargs grep -l "your_kernel_name"
# Check if FlashInfer has the API
grep -r "def your_kernel" tmp/flashinfer/python/flashinfer/
```
### Reference Implementation Guidelines
1. **Pure PyTorch**: Use only `torch` operations, no external kernels
2. **Step-by-step**: Break down computation into clear steps
3. **Match signatures**: Input/output names must match definition schema
4. **Include all outputs**: Return all outputs specified in definition (e.g., both `output` and `lse` for attention)
Example reference implementation pattern:
```python
import torch
import math
def run(q, k, v, ...):
# Step 1: Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
# Step 2: Apply softmax
attn_weights = torch.softmax(scores, dim=-1)
# Step 3: Compute output
output = torch.matmul(attn_weights, v)
# Step 4: Compute log-sum-exp (if needed)
lse = torch.logsumexp(scores, dim=-1)
return output, lse
```
## Kernel Type to Model Mapping
| Model | Attention | MLP | Normalization |
|-------|-----------|-----|---------------|
| DeepSeek V3/R1 | MLA | MoE | RMSNorm |
| Llama 3.x | GQA | Dense | RMSNorm |
| Qwen2 MoE | GQA | MoE | RMSNorm |
| Mixtral | GQA | MoE | RMSNorm |
## Error Handling
### Model Not Found
- **Error**: Model file doesn't exist in SGLang
- **Handling**: List available models, suggest closest match
### Unsupported Kernel Type
- **Error**: Unknown kernel pattern in model
- **Handling**: Log warning, skip kernel, report for manual review
### Duplicate with Conflict
- **Error**: Definition exists with different parameters
- **Handling**: Create new versioned definition, flag for review
## Integration with Other Skills
```bash
# Complete workflow
/clone-repos
# Extract from multiple models
/extract-kernel-definitions --model-name deepseek_v3
/extract-kernel-definitions --model-name llama
/extract-kernel-definitions --model-name qwen2_moe
# Add tests for new definitions
/add-reference-tests --op-type mla_paged
/add-reference-tests --op-type moe
```
## Notes
- Each kernel may have multiple variants for different execution modes (prefill/decode)
- Quantized kernels are treated as separate definitions
- Reference implementations prioritize clarity over performance
- Model tags enable tracing which models use each kernel
- Deduplication is essential for maintaining a clean dataset
## Maintaining This Document
Update this file when adding new op_types, changing Definition JSON schema, or modifying extraction patterns.
## See Also
- [clone-repos](../clone-repos/SKILL.md)
- [add-reference-tests](../add-reference-tests/SKILL.md)
Name Size