---
name: add-reference-tests
description: Add pytest tests to validate reference implementations in flashinfer_trace against FlashInfer or SGLang ground truth. Use when validating kernel definitions, adding tests for new op_types, or verifying reference implementations are correct.
---
# Add Reference Tests
Add tests to validate reference implementations in `./flashinfer_trace/`. Ground truth is sourced from FlashInfer repository or SGLang when FlashInfer doesn't have the implementation.
## Description
This skill creates test cases under `./flashinfer_trace/tests/references/` to validate that reference implementations in Definition JSON files produce correct outputs. The ground truth comes from:
1. **FlashInfer repository** (preferred): Official optimized GPU kernels in `tmp/flashinfer/`
2. **SGLang repository** (fallback): When FlashInfer doesn't have the kernel, use `tmp/sglang/`
## Usage
```bash
# Test all definitions of a specific op_type
/add-reference-tests --op-type mla_paged
/add-reference-tests --op-type moe
/add-reference-tests --op-type gqa_paged
/add-reference-tests --op-type rmsnorm
# Test a specific definition
/add-reference-tests --definition-name mla_paged_decode_h16_ckv512_kpe64_ps1
# Test all definitions in the definitions directory
/add-reference-tests --all
# Test with custom tolerance
/add-reference-tests --definition-name rmsnorm_h4096 --tolerance 1e-4
```
## Parameters
- `definition_name` (optional): Specific definition to test (e.g., "mla_paged_decode_h16_ckv512_kpe64_ps1")
- `op_type` (optional): Test all definitions of a specific op_type (e.g., "mla_paged", "moe", "rmsnorm")
- `all` (optional): Test all definitions in the definitions directory
- `test_sizes` (optional): List of test sizes ["small", "medium", "large"] (default: ["small", "medium"])
- `tolerance` (optional): Numerical tolerance for comparison (default: 1e-3 for fp16, 1e-5 for fp32)
## Prerequisites
Run `/clone-repos` first to set up the `tmp/` directory with SGLang and FlashInfer (the `flashinfer_trace/` directory is already part of this repository).
## What This Skill Does
### Phase 1: Definition Discovery
1. **Load Target Definitions**:
- If `definition_name` specified: load single definition
- If `op_type` specified: load all definitions matching op_type from `flashinfer_trace/definitions/{op_type}/`
- If `all`: scan all definitions
2. **Check Existing Tests**:
- Scan `flashinfer_trace/tests/references/` for existing test files
- Skip definitions that already have tests (unless force=true)
3. **Parse Definition Schema**:
- Extract axes (const/var), inputs, outputs
- Identify required shapes and dtypes
- Parse reference implementation code
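A minimal sketch of this discovery phase, assuming the layout `flashinfer_trace/definitions/{op_type}/{definition_name}.json` (the helper name `discover_definitions` is hypothetical):
```python
import json
from pathlib import Path

DEFINITIONS_DIR = Path("flashinfer_trace/definitions")
TESTS_DIR = Path("flashinfer_trace/tests/references")

def discover_definitions(op_type=None, definition_name=None, force=False):
    """Yield (name, definition) pairs that still need tests."""
    pattern = f"{op_type}/*.json" if op_type else "*/*.json"
    for def_file in sorted(DEFINITIONS_DIR.glob(pattern)):
        name = def_file.stem
        if definition_name and name != definition_name:
            continue
        # Skip definitions that already have tests (unless force=True)
        if not force and (TESTS_DIR / f"test_{name}.py").exists():
            continue
        yield name, json.loads(def_file.read_text())
```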
### Phase 2: Ground Truth Discovery
For each definition, locate ground truth implementation using this priority order:
#### For Model Constants: HuggingFace + SGLang (Required)
**Note**: See [extract-kernel-definitions](../extract-kernel-definitions/SKILL.md#for-model-constants-huggingface--sglang-required) for detailed guidance on sourcing model constants from HuggingFace and SGLang.
#### For Ground Truth Execution: FlashInfer API (Primary)
- **When**: FlashInfer has the kernel implementation (MOST kernels)
- **Location**: `tmp/flashinfer/python/flashinfer/`
- **Use for**: Running optimized GPU kernel as ground truth
- **Examples**:
```python
import flashinfer
# GQA decode
flashinfer.BatchDecodeWithPagedKVCacheWrapper(...)
# GQA prefill
flashinfer.BatchPrefillWithPagedKVCacheWrapper(...)
# MLA
flashinfer.mla.BatchMLAPagedAttentionWrapper(...)
# RMSNorm
flashinfer.norm.rmsnorm(...)
```
- **Important**: FlashInfer is the PRIMARY ground truth for correctness validation
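For illustration, a minimal sketch of driving the decode wrapper as ground truth. The exact `plan()`/`run()` argument names vary across FlashInfer versions, so treat them as assumptions and verify against the unit tests in `tmp/flashinfer/tests/`:
```python
import torch
import flashinfer

# Sketch only: plan()/run() signatures differ across FlashInfer versions;
# confirm against tmp/flashinfer/tests/test_batch_decode.py.
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 1
batch_size, pages_per_seq = 4, 16

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda")
# Combined KV cache in NHD layout: (num_pages, 2, page_size, num_kv_heads, head_dim)
kv_cache = torch.randn(batch_size * pages_per_seq, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.bfloat16, device="cuda")
kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * pages_per_seq
kv_indices = torch.arange(batch_size * pages_per_seq, device="cuda", dtype=torch.int32)
kv_last_page_len = torch.full((batch_size,), page_size, device="cuda", dtype=torch.int32)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.plan(kv_indptr, kv_indices, kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size)
gt_output = wrapper.run(q, kv_cache)
```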
#### For Ground Truth Execution: SGLang (Fallback ONLY)
- **When**: FlashInfer does NOT have the kernel (e.g., some MoE variants)
- **Location**: `tmp/sglang/python/sglang/srt/layers/`
- **Use for**: Ground truth when FlashInfer unavailable
- **Examples**:
```
layers/moe/fused_moe.py # MoE when FlashInfer MoE unavailable
```
- **Important**: ONLY use SGLang as ground truth when FlashInfer doesn't support the kernel
#### Ground Truth Source Mapping
| Op Type | Ground Truth Source | FlashInfer API | Fallback (if FlashInfer unavailable) |
|---------|---------------------|----------------|--------------------------------------|
| `rmsnorm` | FlashInfer | `flashinfer.norm.rmsnorm` | N/A (FlashInfer has it) |
| `fused_add_rmsnorm` | FlashInfer | `flashinfer.norm.fused_add_rmsnorm` | N/A (FlashInfer has it) |
| `gqa_paged` | FlashInfer | `flashinfer.BatchDecodeWithPagedKVCacheWrapper`, `flashinfer.BatchPrefillWithPagedKVCacheWrapper` | N/A |
| `gqa_ragged` | FlashInfer | `flashinfer.BatchPrefillWithRaggedKVCacheWrapper` | N/A |
| `mla_paged` | FlashInfer | `flashinfer.mla.BatchMLAPagedAttentionWrapper` | N/A (FlashInfer has it) |
| `moe` | SGLang (fallback) | N/A (FlashInfer MoE may not cover all variants) | `sglang/layers/moe/fused_moe.py` |
| `gemm` | PyTorch | N/A | `torch.nn.functional.linear` |
| `sampling` | FlashInfer | `flashinfer.sampling.*` | N/A |
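As a concrete end-to-end example for the simplest row above, a sketch comparing a pure-PyTorch RMSNorm reference against the FlashInfer kernel (the `flashinfer.norm.rmsnorm(input, weight, eps)` signature is an assumption to confirm against `tmp/flashinfer/tests/test_norm.py`):
```python
import torch
import flashinfer

def rmsnorm_reference(x, weight, eps=1e-6):
    """Pure-PyTorch RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    x_f32 = x.float()
    rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * rms * weight.float()).to(x.dtype)

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(4096, dtype=torch.float16, device="cuda")

ref = rmsnorm_reference(x, w)
gt = flashinfer.norm.rmsnorm(x, w)  # FlashInfer ground truth
torch.testing.assert_close(ref.float(), gt.float(), atol=1e-3, rtol=1e-3)
```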
#### Reference `run()` Function Sources
**Note**: For detailed guidance on sourcing reference implementations, see the [extract-kernel-definitions](../extract-kernel-definitions/SKILL.md) skill's "Reference Implementation Sources" section.
**Quick Reference**:
- **Primary**: FlashInfer unit tests at `tmp/flashinfer/tests/` (e.g., `test_batch_decode.py`, `test_norm.py`)
- **Fallback**: SGLang vanilla implementations at `tmp/sglang/python/sglang/srt/layers/` (only when FlashInfer unavailable)
### Phase 3: Test Generation
For each definition, generate a test file following the standards below.
## Test File Standards
### File Structure
Each test file should follow this structure:
```python
import json
import math
from pathlib import Path
import numpy as np
import pytest
import torch
# Ground truth imports (with availability checks)
try:
    import flashinfer
    from flashinfer.xxx import some_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

# Module-level constants from definition
HIDDEN_SIZE = 7168
NUM_EXPERTS = 256
# ... other constants

TRACE_ROOT = Path(__file__).resolve().parents[2]
WORKLOAD_JSONL_PATH = TRACE_ROOT / "workloads" / "op_type" / "definition_name.jsonl"

@torch.no_grad()
def run(...):
    """Reference implementation matching the definition."""
    # Check constants
    assert hidden_size == HIDDEN_SIZE
    ...

def generate_random_inputs(..., device="cuda"):
    """Generate random inputs for testing."""
    ...
    return {...}

def test_correctness(..., atol=1e-2, rtol=5e-2):
    """Test correctness of reference implementation against ground truth."""
    ...

def main():
    """Run comprehensive tests."""
    ...

if __name__ == "__main__":
    main()
```
### Coding Style Patterns
1. **Constants at Module Level**: Define model-specific constants at the top
```python
# DeepSeek V3/R1 MoE constants
HIDDEN_SIZE = 7168
INTERMEDIATE_SIZE = 2048
NUM_EXPERTS_GLOBAL = 256
NUM_LOCAL_EXPERTS = 32 # EP=8
```
2. **Use `@torch.no_grad()` Decorator**: For all reference implementations and test functions
3. **Input Generator Function**: Separate function `generate_random_inputs(...)` that returns a dict
4. **Test Function Pattern**:
```python
def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2):
    """Test correctness of reference implementation against ground truth."""
    print(f"\n{'='*60}")
    print(f"Testing {description}: {params}")
    print(f"{'='*60}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("WARNING: CUDA not available, skipping test")
        return

    # Generate inputs
    inputs = generate_random_inputs(...)

    # Run reference implementation
    print("\nRunning reference implementation...")
    ref_output = run(**inputs)

    # Run ground truth
    print("Running ground truth (FlashInfer/SGLang)...")
    gt_output = ground_truth_fn(**inputs)

    # Compare outputs
    print("\nComparing outputs...")
    # ... detailed comparison
```
### Correctness Checking Patterns
#### Standard Tolerance Check (FP16/BF16)
```python
# Convert to float32 for comparison
ref_f32 = ref_output.float()
gt_f32 = gt_output.float()
# Compute detailed error metrics
abs_diff = torch.abs(ref_f32 - gt_f32)
rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8)
max_abs_diff = abs_diff.max().item()
max_rel_diff = rel_diff.max().item()
mean_abs_diff = abs_diff.mean().item()
mean_rel_diff = rel_diff.mean().item()
print(f"\nOutput tensor comparison:")
print(f"Max absolute difference: {max_abs_diff:.6e}")
print(f"Max relative difference: {max_rel_diff:.6e}")
print(f"Mean absolute difference: {mean_abs_diff:.6e}")
print(f"Mean relative difference: {mean_rel_diff:.6e}")
# Cosine similarity and MSE
cos_sim = torch.nn.functional.cosine_similarity(
ref_f32.flatten(), gt_f32.flatten(), dim=0
).item()
mse = torch.mean((ref_f32 - gt_f32) ** 2).item()
print(f"Cosine similarity: {cos_sim:.6f}")
print(f"MSE: {mse:.6e}")
# Check tolerance
all_close = torch.allclose(ref_f32, gt_f32, atol=atol, rtol=rtol)
if all_close:
    print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})")
else:
    print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})")
```
#### Hit Ratio Check (for FP8/Quantized Kernels)
For quantized kernels with higher variance, use hit ratio instead of strict allclose:
```python
# Check what percentage of elements pass the tolerance check
left = (ref_f32 - gt_f32).abs()
right = atol + rtol * gt_f32.abs()
ok = left <= right
hit_ratio = ok.float().mean().item()
print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {percent * 100:.2f}%)")
return hit_ratio >= percent # e.g., 85%
```
#### Error Location Debugging
When tests fail, show top error locations:
```python
if not all_close:
    flat = abs_diff.flatten()
    k = min(5, flat.numel())
    topv, topi = torch.topk(flat, k)
    print(f"\nTop-{k} absolute error locations:")
    for rank in range(k):
        idx = topi[rank].item()
        # Convert flat index to multi-dimensional
        # ... compute indices
        print(f"  [{indices}]: ref={ref_val:.6e}, gt={gt_val:.6e}, diff={topv[rank].item():.6e}")
```
### Tolerance Guidelines
| Data Type | atol | rtol | Notes |
|-----------|------|------|-------|
| float32 | 1e-5 | 1e-5 | Strictest |
| float16 | 1e-3 | 1e-3 | Standard |
| bfloat16 | 8e-3 | 1e-2 | 0.8% abs, 1% rel |
| float8_e4m3fn | 1e-1 | 2e-1 | Use hit ratio ≥85% |
| nvfp4 | 1e-1 | 2e-1 | Use hit ratio ≥85% |
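A small helper that encodes this table (a sketch; per-kernel overrides still apply):
```python
import torch

# Map dtype -> (atol, rtol, use_hit_ratio), mirroring the table above.
# nvfp4 has no native torch dtype; handle it at the kernel level.
TOLERANCES = {
    torch.float32: (1e-5, 1e-5, False),
    torch.float16: (1e-3, 1e-3, False),
    torch.bfloat16: (8e-3, 1e-2, False),
}
if hasattr(torch, "float8_e4m3fn"):  # requires a recent PyTorch
    TOLERANCES[torch.float8_e4m3fn] = (1e-1, 2e-1, True)  # hit ratio >= 85%

def default_tolerance(dtype):
    """Return (atol, rtol, use_hit_ratio) for a dtype, per the table above."""
    return TOLERANCES.get(dtype, (1e-3, 1e-3, False))
```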
## Multi-Ground-Truth Testing
### Pattern for Multiple Ground Truths
When both FlashInfer and SGLang implementations are available, test against both:
```python
# Ground truth imports
try:
    from flashinfer.xxx import flashinfer_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

try:
    from sglang.srt.layers.xxx import sglang_kernel
    SGLANG_AVAILABLE = True
except ImportError:
    SGLANG_AVAILABLE = False

def test_correctness_vs_flashinfer(...):
    """Test reference against FlashInfer ground truth."""
    if not FLASHINFER_AVAILABLE:
        pytest.skip("FlashInfer not available")
    ref_output = run(**inputs)
    fi_output = flashinfer_kernel(**inputs)
    assert_close(ref_output, fi_output, atol=atol, rtol=rtol)

def test_correctness_vs_sglang(...):
    """Test reference against SGLang ground truth."""
    if not SGLANG_AVAILABLE:
        pytest.skip("SGLang not available")
    ref_output = run(**inputs)
    sg_output = sglang_kernel(**inputs)
    assert_close(ref_output, sg_output, atol=atol, rtol=rtol)

def test_ground_truths_match(...):
    """Test that FlashInfer and SGLang produce consistent results."""
    if not (FLASHINFER_AVAILABLE and SGLANG_AVAILABLE):
        pytest.skip("Both ground truths not available")
    fi_output = flashinfer_kernel(**inputs)
    sg_output = sglang_kernel(**inputs)
    assert_close(fi_output, sg_output, atol=atol, rtol=rtol)
```
### Comprehensive Test Runner
```python
def main():
    """Run comprehensive tests against all available ground truths."""
    print("Testing Reference Implementation")

    # Test configurations
    test_configs = [
        (1, 16),  # Small
        (4, 32),  # Medium
        (8, 64),  # Large
    ]

    results = {"flashinfer": [], "sglang": []}
    for config in test_configs:
        if FLASHINFER_AVAILABLE:
            try:
                ok = test_correctness_vs_flashinfer(*config)
                results["flashinfer"].append(ok)
            except Exception as e:
                print(f"FlashInfer test failed: {e}")
                results["flashinfer"].append(False)
        if SGLANG_AVAILABLE:
            try:
                ok = test_correctness_vs_sglang(*config)
                results["sglang"].append(ok)
            except Exception as e:
                print(f"SGLang test failed: {e}")
                results["sglang"].append(False)

    # Summary
    print(f"\n{'='*60}")
    print("Summary:")
    for source, passed_list in results.items():
        if passed_list:
            print(f"  {source}: {sum(passed_list)}/{len(passed_list)} tests passed")
    print(f"{'='*60}")
```
## Standard Test Generation
For each definition, generate a test file:
1. **Create Test Class** with:
- Fixture for loading definition
- Fixture for compiling reference implementation
- Fixture for ground truth function
2. **Generate Test Inputs**:
- Parse definition schema for input shapes and dtypes
- Generate random tensors matching specifications
- Handle both constant and variable axes
3. **Create Test Methods**:
- `test_output_shape`: Verify output shapes match definition
- `test_output_dtype`: Verify output dtypes match definition
- `test_numerical_correctness`: Compare reference vs ground truth
- `test_determinism`: Verify reproducible results
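Put together, a hedged sketch of one generated test class, assuming the `load_definition`/`compile_reference` helpers from the conftest template below are importable and that the definition stores its reference code under a `"reference"` key (both assumptions):
```python
import pytest
import torch

from conftest import load_definition, compile_reference

class TestRMSNormH4096:
    @pytest.fixture(scope="class")
    def definition(self):
        return load_definition("rmsnorm_h4096")  # hypothetical definition name

    @pytest.fixture(scope="class")
    def reference(self, definition):
        return compile_reference(definition["reference"])  # field name assumed

    def test_output_shape(self, reference, device):
        x = torch.randn(4, 4096, dtype=torch.float16, device=device)
        w = torch.randn(4096, dtype=torch.float16, device=device)
        out = reference(x, w)
        assert out.shape == x.shape

    def test_determinism(self, reference, device):
        torch.manual_seed(0)
        x = torch.randn(4, 4096, dtype=torch.float16, device=device)
        w = torch.randn(4096, dtype=torch.float16, device=device)
        assert torch.equal(reference(x, w), reference(x, w))
```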
### Phase 4: Test Cases
Generate multiple test cases with varying sizes:
```python
# Small: Quick smoke tests
SMALL_SIZES = {
    "batch_size": [1, 2],
    "seq_len": [1, 16],
    "num_pages": [1, 4],
}

# Medium: Standard tests
MEDIUM_SIZES = {
    "batch_size": [4, 8, 16],
    "seq_len": [64, 128, 256],
    "num_pages": [16, 32, 64],
}
```
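These size tables can feed `pytest.mark.parametrize` directly; a sketch:
```python
import itertools
import pytest

# Cartesian product of the size axes above, one pytest case per combination.
SMALL_CASES = list(itertools.product(
    SMALL_SIZES["batch_size"], SMALL_SIZES["seq_len"], SMALL_SIZES["num_pages"]
))

@pytest.mark.parametrize("batch_size,seq_len,num_pages", SMALL_CASES)
def test_small(batch_size, seq_len, num_pages):
    ...
```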
### Phase 5: Write Test Files
Write the generated test files to `flashinfer_trace/tests/references/`.
## Output Structure
```
flashinfer_trace/tests/references/
├── conftest.py # Shared fixtures and utilities
├── test_rmsnorm.py # RMSNorm tests
├── test_gqa_paged.py # GQA paged tests
├── test_mla_paged.py # MLA paged tests
├── test_moe.py # MoE tests
└── test_gemm.py # GEMM tests
```
## Test File Template
```python
"""Tests for {definition_name} reference implementation."""
import math
from pathlib import Path
import numpy as np
import torch
# Ground truth imports with availability checks
try:
import flashinfer
from flashinfer.xxx import some_kernel
FLASHINFER_AVAILABLE = True
except ImportError:
FLASHINFER_AVAILABLE = False
try:
from sglang.srt.layers.xxx import sglang_kernel
SGLANG_AVAILABLE = True
except ImportError:
SGLANG_AVAILABLE = False
# Module-level constants (from definition)
NUM_QO_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
PAGE_SIZE = 1
TRACE_ROOT = Path(__file__).resolve().parents[2]
@torch.no_grad()
def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
"""Reference implementation matching the definition schema."""
batch_size, num_qo_heads, head_dim = q.shape
_, page_size, num_kv_heads, _ = k_cache.shape
# Check constants
assert num_qo_heads == NUM_QO_HEADS
assert num_kv_heads == NUM_KV_HEADS
assert head_dim == HEAD_DIM
assert page_size == PAGE_SIZE
device = q.device
# Reference computation (pure PyTorch)
output = torch.zeros((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device)
lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device)
# ... detailed step-by-step computation ...
return output, lse
def generate_random_inputs(
batch_size,
max_seq_len,
num_attention_heads=NUM_QO_HEADS,
num_key_value_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
page_size=PAGE_SIZE,
device="cuda",
):
"""Generate random inputs for testing."""
# Generate tensors matching definition schema
q = torch.randn(batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device)
# ... generate other inputs ...
return {
"q": q,
"k_cache": k_cache,
"v_cache": v_cache,
"kv_indptr": kv_indptr,
"kv_indices": kv_indices,
"sm_scale": sm_scale,
}
def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2):
"""Test correctness of reference implementation against FlashInfer."""
print(f"\n{'='*60}")
print(f"Testing batch_size={batch_size}, max_seq_len={max_seq_len}")
print(f"{'='*60}")
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
print("WARNING: CUDA not available, skipping test")
return
# Generate inputs
inputs = generate_random_inputs(batch_size, max_seq_len, device=device)
print(f"Generated sequences with shapes: q={inputs['q'].shape}")
# Run reference implementation
print("\nRunning reference implementation...")
ref_output, ref_lse = run(**{k: v for k, v in inputs.items() if k != 'extra_keys'})
# Run FlashInfer ground truth
print("Running FlashInfer...")
# ... FlashInfer setup and execution ...
# Compare outputs
print("\nComparing outputs...")
# Convert to float32 for comparison
ref_f32 = ref_output.float()
fi_f32 = fi_output.float()
# Compute errors
abs_diff = torch.abs(ref_f32 - fi_f32)
rel_diff = abs_diff / (torch.abs(fi_f32) + 1e-8)
max_abs_diff = abs_diff.max().item()
max_rel_diff = rel_diff.max().item()
mean_abs_diff = abs_diff.mean().item()
mean_rel_diff = rel_diff.mean().item()
print(f"\nOutput tensor comparison:")
print(f"Max absolute difference: {max_abs_diff:.6e}")
print(f"Max relative difference: {max_rel_diff:.6e}")
print(f"Mean absolute difference: {mean_abs_diff:.6e}")
print(f"Mean relative difference: {mean_rel_diff:.6e}")
# Cosine similarity and MSE
cos_sim = torch.nn.functional.cosine_similarity(
ref_f32.flatten(), fi_f32.flatten(), dim=0
).item()
mse = torch.mean((ref_f32 - fi_f32) ** 2).item()
print(f"Cosine similarity: {cos_sim:.6f}")
print(f"MSE: {mse:.6e}")
# Check if outputs match within tolerance
all_close = torch.allclose(ref_f32, fi_f32, atol=atol, rtol=rtol)
if all_close:
print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})")
else:
print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})")
# Show top error locations for debugging
flat = abs_diff.flatten()
k = min(5, flat.numel())
topv, topi = torch.topk(flat, k)
print(f"\nTop-{k} absolute error locations:")
for rank in range(k):
idx = topi[rank].item()
print(f" idx={idx}: ref={ref_f32.flatten()[idx].item():.6e}, "
f"fi={fi_f32.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}")
return all_close
def main():
"""Run comprehensive tests."""
print("Testing Reference Implementation vs FlashInfer")
# Test configurations
test_configs = [
(1, 16), # Single batch
(4, 32), # Small batch
(8, 64), # Medium batch
(16, 128), # Large batch
]
passed = 0
total = len(test_configs)
for batch_size, max_seq_len in test_configs:
try:
if test_correctness(batch_size, max_seq_len):
passed += 1
except Exception as e:
print(f"✗ Test failed with exception: {str(e)}")
import traceback
traceback.print_exc()
print(f"\n{'='*60}")
print(f"Summary: {passed}/{total} tests passed")
print(f"{'='*60}")
if passed == total:
print("✓ All tests passed!")
else:
print(f"✗ {total - passed} tests failed")
if __name__ == "__main__":
main()
```
## conftest.py Template
```python
"""Shared test fixtures for reference implementation tests."""
import json
import math
import pytest
import torch
from pathlib import Path
DEFINITIONS_DIR = Path(__file__).parent.parent / "definitions"
WORKLOADS_DIR = Path(__file__).parent.parent / "workloads"
@pytest.fixture
def device():
"""Get test device (CUDA if available)."""
return "cuda" if torch.cuda.is_available() else "cpu"
def load_definition(name: str) -> dict:
"""Load a definition JSON by name."""
for op_dir in DEFINITIONS_DIR.iterdir():
if op_dir.is_dir():
def_file = op_dir / f"{name}.json"
if def_file.exists():
with open(def_file) as f:
return json.load(f)
raise FileNotFoundError(f"Definition {name} not found")
def compile_reference(reference_code: str):
"""Compile reference implementation to callable function."""
namespace = {"torch": torch, "math": math, "np": __import__("numpy")}
exec(reference_code, namespace)
return namespace["run"]
def assert_close(actual, expected, rtol=1e-3, atol=1e-3):
"""Assert tensors are close within tolerance."""
if isinstance(actual, tuple):
for a, e in zip(actual, expected):
assert_close(a, e, rtol, atol)
elif isinstance(actual, dict):
for k in actual:
assert_close(actual[k], expected[k], rtol, atol)
else:
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
def compute_error_metrics(ref, gt, name="output"):
"""Compute and print detailed error metrics."""
ref_f32 = ref.float()
gt_f32 = gt.float()
abs_diff = torch.abs(ref_f32 - gt_f32)
rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8)
print(f"\n{name} comparison:")
print(f" Max absolute difference: {abs_diff.max().item():.6e}")
print(f" Max relative difference: {rel_diff.max().item():.6e}")
print(f" Mean absolute difference: {abs_diff.mean().item():.6e}")
print(f" Mean relative difference: {rel_diff.mean().item():.6e}")
cos_sim = torch.nn.functional.cosine_similarity(
ref_f32.flatten(), gt_f32.flatten(), dim=0
).item()
mse = torch.mean((ref_f32 - gt_f32) ** 2).item()
print(f" Cosine similarity: {cos_sim:.6f}")
print(f" MSE: {mse:.6e}")
return abs_diff, rel_diff
def check_hit_ratio(ref, gt, atol, rtol, required_percent=0.85):
"""Check if hit ratio meets threshold (for quantized kernels)."""
ref_f32 = ref.float()
gt_f32 = gt.float()
left = (ref_f32 - gt_f32).abs()
right = atol + rtol * gt_f32.abs()
ok = left <= right
hit_ratio = ok.float().mean().item()
print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {required_percent * 100:.2f}%)")
return hit_ratio >= required_percent
```
## Implementation Steps
When executing this skill:
1. **Identify definitions to test**:
```bash
ls flashinfer_trace/definitions/{op_type}/
```
2. **Check for existing tests**:
```bash
ls flashinfer_trace/tests/references/
```
3. **For each definition**:
- Read the definition JSON
- Identify ground truth source (FlashInfer or SGLang)
- Generate test class with appropriate fixtures
- Generate test methods for shape, dtype, and numerical correctness
4. **Create test file**:
```bash
# Create tests directory if needed
mkdir -p flashinfer_trace/tests/references/
```
5. **Write test file**:
- If testing multiple definitions of same op_type, combine into one file
- Each definition gets its own test class
6. **Create/update conftest.py** with shared fixtures
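A minimal driver tying these steps together (a sketch reusing the hypothetical `discover_definitions` from the Phase 1 sketch; `render_test_class` is likewise hypothetical and would fill the Test File Template above):
```python
from pathlib import Path

def generate_tests(op_type):
    """Steps 1-5: discover definitions, skip existing tests, write one file per op_type."""
    tests_dir = Path("flashinfer_trace/tests/references")
    tests_dir.mkdir(parents=True, exist_ok=True)
    classes = [
        render_test_class(name, definition)  # hypothetical template renderer
        for name, definition in discover_definitions(op_type=op_type)
    ]
    if classes:
        (tests_dir / f"test_{op_type}.py").write_text("\n\n".join(classes))
```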
## Running Tests
After generating tests, run from the project root:
```bash
# Run all reference tests
pytest flashinfer_trace/tests/references/ -v
# Run specific test file
pytest flashinfer_trace/tests/references/test_mla_paged.py -v
# Run with GPU
pytest flashinfer_trace/tests/references/ -v --device cuda
# Run with verbose output
pytest flashinfer_trace/tests/references/ -v -s
```
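Note that `--device` is not a built-in pytest flag; to make the invocation above work, a minimal conftest hook is one option (a sketch layered on the conftest template, not part of it):
```python
# conftest.py additions (sketch): make `pytest --device cuda` work.
import pytest
import torch

def pytest_addoption(parser):
    parser.addoption("--device", action="store", default=None,
                     help="Force test device: cuda or cpu")

@pytest.fixture
def device(request):
    forced = request.config.getoption("--device")
    if forced:
        return forced
    return "cuda" if torch.cuda.is_available() else "cpu"
```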
## Error Handling
### Ground Truth Not Available
- **Error**: Ground truth implementation not found
- **Handling**: Follow the priority order (FlashInfer primary, SGLang fallback):
1. Check the FlashInfer API directly
2. Check SGLang's FlashInfer API integration
3. Check the SGLang vanilla implementation
4. If none is available, mark the test as skipped with a reason
### Definition Parse Error
- **Error**: Invalid definition JSON
- **Handling**: Report validation errors, skip test generation
### Shape Mismatch
- **Error**: Reference output shape doesn't match definition
- **Handling**: Create failing test, flag for investigation
### Numerical Divergence
- **Error**: Reference differs from ground truth beyond tolerance
- **Handling**: Create failing test with detailed diff report
## Integration with Other Skills
```bash
# Complete workflow
/clone-repos
# Extract definitions
/extract-kernel-definitions --model-name deepseek_v3
# Add tests for new definitions
/add-reference-tests --op-type mla_paged
/add-reference-tests --op-type moe
# Run tests from project root
pytest flashinfer_trace/tests/references/ -v
```
## Notes
- Tests run on GPU by default; when CUDA is unavailable (e.g., CI environments) they skip with a warning
- Tolerance varies by dtype: looser for fp16 (1e-3), stricter for fp32 (1e-5)
- Some kernels may not have FlashInfer ground truth yet
- Test parametrization covers common batch/sequence sizes
- Tests marked with `@pytest.mark.slow` for large sizes
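For the `slow` marker to be recognized without warnings, register it, e.g. in conftest.py (a sketch, assuming no marker registration exists yet):
```python
# conftest.py addition (sketch): register the `slow` marker.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "slow: large test sizes, deselect with -m 'not slow'"
    )
```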
## Maintaining This Document
Update this file when changing ground truth sources, test patterns, tolerance values, or adding new op_types.
## See Also
- [clone-repos](../clone-repos/SKILL.md)
- [extract-kernel-definitions](../extract-kernel-definitions/SKILL.md)