corridorkey-mlx
MLX inference port of CorridorKey for Apple Silicon.
Architecture
```
RGB image + coarse alpha hint (4ch)
          │
          ▼
┌──────────────────┐
│  Hiera backbone  │  (timm, features_only)
│  → 4 multiscale  │
│   feature maps   │
└──────────────────┘
          │
     ┌────┴────┐
     ▼         ▼
 ┌───────┐ ┌───────┐
 │ Alpha │ │  FG   │
 │ head  │ │ head  │
 │ (1ch) │ │ (3ch) │
 └───────┘ └───────┘
     │         │
     └────┬────┘
          ▼
┌──────────────────┐
│   CNN Refiner    │  RGB + coarse preds (7ch)
│  → delta logits  │  → sigmoid
└──────────────────┘
          │
          ▼
   final alpha + fg
```
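The refiner stage of the diagram can be sketched in NumPy as follows. This is a minimal illustration, not the repo's code. It assumes NCHW tensors (matching the reference fixtures), assumes the 7-channel refiner input is RGB concatenated with the sigmoid-activated coarse alpha and FG, and assumes the final outputs are sigmoid(upsampled coarse logits + delta logits) with delta channel 0 for alpha and channels 1-3 for FG. The helper names are hypothetical, and the refiner network itself is not modeled.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function written via tanh to avoid overflow for large |x|."""
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def refiner_input(
    rgb: np.ndarray,              # (N, 3, H, W) image
    alpha_logits_up: np.ndarray,  # (N, 1, H, W) upsampled alpha logits
    fg_logits_up: np.ndarray,     # (N, 3, H, W) upsampled FG logits
) -> np.ndarray:
    """Assumed 7-channel refiner input: RGB + coarse alpha + coarse FG."""
    coarse = np.concatenate([sigmoid(alpha_logits_up), sigmoid(fg_logits_up)], axis=1)
    return np.concatenate([rgb, coarse], axis=1)  # (N, 7, H, W)


def combine_refiner_output(
    alpha_logits_up: np.ndarray,
    fg_logits_up: np.ndarray,
    delta_logits: np.ndarray,  # (N, 4, H, W): assumed channel 0 alpha, 1-3 FG
) -> tuple[np.ndarray, np.ndarray]:
    """Assumed combination: add refiner deltas to the coarse logits, then sigmoid."""
    alpha_final = sigmoid(alpha_logits_up + delta_logits[:, 0:1])
    fg_final = sigmoid(fg_logits_up + delta_logits[:, 1:4])
    return alpha_final, fg_final
```

Working in logit space means a zero delta leaves the coarse prediction unchanged, so the refiner only has to learn corrections.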
Phased Roadmap
| Phase | Scope | Status |
|---|---|---|
| 1 | PyTorch reference harness + fixture dump | Done |
| 2 | MLX decoder/refiner blocks + parity tests | Done |
| 3 | Checkpoint conversion (PyTorch → MLX) | Done |
| 4 | Hiera backbone port | Done |
| 5 | Full model assembly + e2e parity | Done |
| 6 | Optimization + benchmarking | Done |
See prompts/ for detailed phase instructions.
Usage
Setup
```bash
uv sync --group dev
```
Download weights
Model weights (~400 MB) are distributed via GitHub Releases rather than committed to git, since binary blobs bloat the history and exceed Git LFS free-tier limits.
CLI (quickest):
```bash
# download latest release to platform cache dir
uv run python -m corridorkey_mlx weights download
# specific tag
uv run python -m corridorkey_mlx weights download --tag v1.0.0
# override asset name
uv run python -m corridorkey_mlx weights download --asset corridorkey_mlx.safetensors
# force re-download
uv run python -m corridorkey_mlx weights download --force
# print local path (for scripting)
WEIGHTS=$(uv run python -m corridorkey_mlx weights download --print-path)
```
If installed as a package, a corridorkey-weights console script is also available:
```bash
corridorkey-weights download --tag v1.0.0 --print-path
```
Where weights are cached:
| Platform | Path |
|---|---|
| macOS | ~/Library/Caches/corridorkey_mlx/weights/<tag>/ |
| Linux | ~/.cache/corridorkey_mlx/weights/<tag>/ |
| Windows | %LOCALAPPDATA%\corridorkey_mlx\Cache\weights\<tag>\ |
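The cache-location table can be expressed as a small lookup. This sketch assumes the paths are derived from sys.platform and, on Windows, the LOCALAPPDATA environment variable; the real downloader may resolve them differently (for example through a library such as platformdirs). weights_cache_dir is a hypothetical name.

```python
import os
import sys
from pathlib import Path


def weights_cache_dir(tag: str, platform: str = sys.platform) -> Path:
    """Per-tag weights cache directory, mirroring the cache table (sketch)."""
    if platform == "darwin":
        base = Path.home() / "Library" / "Caches" / "corridorkey_mlx"
    elif platform.startswith("win"):
        # %LOCALAPPDATA% normally points at C:\Users\<user>\AppData\Local
        local = os.environ.get("LOCALAPPDATA", str(Path.home() / "AppData" / "Local"))
        base = Path(local) / "corridorkey_mlx" / "Cache"
    else:
        # Linux and other POSIX systems
        base = Path.home() / ".cache" / "corridorkey_mlx"
    return base / "weights" / tag
```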
Environment variable overrides:
| Variable | Purpose |
|---|---|
| CORRIDORKEY_MLX_WEIGHTS_REPO | owner/repo for a different GitHub repo |
| CORRIDORKEY_MLX_WEIGHTS_TAG | Default tag (instead of latest) |
| CORRIDORKEY_MLX_WEIGHTS_ASSET | Default asset filename |
| GITHUB_TOKEN | Auth token for higher API rate limits |
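A minimal sketch of how these overrides might be resolved, assuming environment variables simply take precedence over built-in defaults. The function names and the owner/repo default are hypothetical placeholders. The URLs are the standard GitHub REST endpoints for the latest release and for a release by tag, and the token is sent as a Bearer Authorization header.

```python
from __future__ import annotations

import os
from collections.abc import Mapping


def resolve_weights_settings(
    env: Mapping[str, str] = os.environ,
    default_repo: str = "owner/repo",  # hypothetical placeholder, not the real default
) -> dict[str, str | None]:
    """Environment variables override defaults; unset values fall back."""
    return {
        "repo": env.get("CORRIDORKEY_MLX_WEIGHTS_REPO", default_repo),
        "tag": env.get("CORRIDORKEY_MLX_WEIGHTS_TAG", "latest"),
        "asset": env.get("CORRIDORKEY_MLX_WEIGHTS_ASSET"),  # None: use built-in name
        "token": env.get("GITHUB_TOKEN"),
    }


def release_api_url(repo: str, tag: str) -> str:
    """GitHub REST endpoint for the latest release or a release by tag."""
    base = f"https://api.github.com/repos/{repo}/releases"
    return f"{base}/latest" if tag == "latest" else f"{base}/tags/{tag}"


def request_headers(token: str | None) -> dict[str, str]:
    """Authenticated requests get a higher GitHub API rate limit."""
    headers = {"Accept": "application/vnd.github+json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return headers
```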
Python API:
```python
from corridorkey_mlx.weights import download_weights

path = download_weights(tag="v1.0.0")
```
Publishing weights to a GitHub Release
Generate the checksum, then upload both files:
```bash
# generate sha256 sidecar
shasum -a 256 corridorkey_mlx.safetensors > corridorkey_mlx.safetensors.sha256
# create release and upload assets
gh release create v1.0.0 \
  corridorkey_mlx.safetensors \
  corridorkey_mlx.safetensors.sha256 \
  --title "v1.0.0" --notes "Initial weights release"
```
The downloader verifies the SHA256 automatically. If no .sha256 sidecar or
SHA256SUMS file is found, verification is skipped with a warning.
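Checksum verification of this kind can be sketched with hashlib. The sketch assumes the sidecar uses the `<hex digest>  <filename>` line format that `shasum -a 256` writes (a SHA256SUMS file with several entries parses the same way). The function names are hypothetical; returning None for a missing entry mirrors the skip-with-warning behavior described above.

```python
from __future__ import annotations

import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash in 1 MiB chunks so large weight files are never fully loaded into memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def expected_digest(sums_text: str, filename: str) -> str | None:
    """Find `filename` in shasum-style output ('<hex>  <name>' per line)."""
    for line in sums_text.splitlines():
        parts = line.strip().split(maxsplit=1)
        # shasum prefixes the name with '*' in binary mode; compare base names only
        if len(parts) == 2 and Path(parts[1].lstrip("*")).name == filename:
            return parts[0].lower()
    return None


def verify_sha256(path: Path, sums_text: str) -> bool | None:
    """True/False when an entry exists; None when there is nothing to check."""
    expected = expected_digest(sums_text, path.name)
    if expected is None:
        return None
    return expected == sha256_of(path)
```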
Convert weights
Convert the PyTorch checkpoint to MLX safetensors (one-time):
```bash
uv run python scripts/convert_weights.py \
  --checkpoint checkpoints/CorridorKey_v1.0.pth \
  --output checkpoints/corridorkey_mlx.safetensors
```
Single-image inference
```bash
uv run python scripts/infer.py \
  --image input.png \
  --hint alpha_hint.png \
  --output-dir output/
```
Outputs output/alpha.png (alpha matte) and output/foreground.png (foreground).
Options:
- `--checkpoint PATH`: MLX safetensors file (default: `checkpoints/corridorkey_mlx.safetensors`)
- `--img-size N`: model input resolution (default: 512)
- `--output-dir DIR`: output directory (default: `output/`)
Python API
```python
from corridorkey_mlx.inference.pipeline import load_model, infer_and_save

model = load_model("checkpoints/corridorkey_mlx.safetensors", img_size=512)
results = infer_and_save(model, "input.png", "alpha_hint.png", "output/")
```
Development
```bash
uv run pytest         # tests
uv run ruff check .   # lint
uv run ruff format .  # format
uv run ty check       # type check
```
For PyTorch reference work:
```bash
uv sync --group reference
```
Reference Fixtures
Phase 1 generates golden reference tensors from PyTorch for MLX parity testing.
Format: a single reference/fixtures/golden.npz file (NumPy compressed archive)
Generate:
```bash
uv run --group reference python scripts/dump_pytorch_reference.py \
  --checkpoint checkpoints/CorridorKey_v1.0.pth
```
Contents (all float32, NCHW, batch=1, img_size=512):
| Key | Shape | Description |
|---|---|---|
| input | (1, 4, 512, 512) | Random input (seed=42) |
| encoder_feature_0 | (1, 112, 128, 128) | Backbone stride-4 |
| encoder_feature_1 | (1, 224, 64, 64) | Backbone stride-8 |
| encoder_feature_2 | (1, 448, 32, 32) | Backbone stride-16 |
| encoder_feature_3 | (1, 896, 16, 16) | Backbone stride-32 |
| alpha_logits | (1, 1, 128, 128) | Alpha decoder output (H/4) |
| fg_logits | (1, 3, 128, 128) | FG decoder output (H/4) |
| alpha_logits_up | (1, 1, 512, 512) | Alpha logits upsampled |
| fg_logits_up | (1, 3, 512, 512) | FG logits upsampled |
| alpha_coarse | (1, 1, 512, 512) | sigmoid(alpha_logits_up) |
| fg_coarse | (1, 3, 512, 512) | sigmoid(fg_logits_up) |
| delta_logits | (1, 4, 512, 512) | Refiner output (10x scaled) |
| alpha_final | (1, 1, 512, 512) | Final alpha prediction |
| fg_final | (1, 3, 512, 512) | Final FG prediction |
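The encoder shapes in the fixture table follow directly from the backbone strides: each feature map has spatial size img_size / stride. The sketch below reproduces them. The channel and stride tuples are read off the table; requiring img_size to be a multiple of 32 is an assumption made here so the division is exact, and the function names are hypothetical.

```python
def encoder_feature_shapes(
    img_size: int = 512,
    channels: tuple[int, ...] = (112, 224, 448, 896),
    strides: tuple[int, ...] = (4, 8, 16, 32),
    batch: int = 1,
) -> list[tuple[int, int, int, int]]:
    """NCHW shapes of the four backbone feature maps for a square input."""
    if img_size % max(strides) != 0:
        raise ValueError(f"img_size {img_size} is not a multiple of {max(strides)}")
    return [(batch, c, img_size // s, img_size // s) for c, s in zip(channels, strides)]


def decoder_logit_shapes(img_size: int = 512, batch: int = 1) -> dict[str, tuple[int, int, int, int]]:
    """Alpha/FG decoder logits sit at the stride-4 resolution (H/4)."""
    side = img_size // 4
    return {"alpha_logits": (batch, 1, side, side), "fg_logits": (batch, 3, side, side)}
```

At img_size=2048 each map covers 16 times the area it does at 512; the stride-4 features alone become (1, 112, 512, 512).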
Parity Results
End-to-end parity vs PyTorch reference (512×512, float32):
| Tensor | Max Abs Error | Mean Abs Error |
|---|---|---|
| alpha_logits | 8.8e-05 | 1.6e-05 |
| fg_logits | 1.5e-04 | 7.2e-06 |
| alpha_coarse | 9.7e-06 | 1.1e-06 |
| fg_coarse | 6.7e-06 | 1.1e-06 |
| delta_logits | 1.1e-04 | 4.3e-06 |
| alpha_final | 2.6e-05 | 8.7e-08 |
| fg_final | 9.5e-06 | 1.1e-06 |
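Error figures like these can be computed from the golden archive with a few lines of NumPy. The sketch assumes the MLX outputs are channels-last (NHWC), as the (1, H, W, 4) input in the tiled-inference example suggests, so the NCHW reference is transposed before comparing. abs_errors and mlx_alpha_final are hypothetical names.

```python
import numpy as np


def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    """Move the channel axis last: (N, C, H, W) -> (N, H, W, C)."""
    return np.transpose(x, (0, 2, 3, 1))


def abs_errors(reference_nchw: np.ndarray, candidate_nhwc: np.ndarray) -> tuple[float, float]:
    """Max and mean absolute error of a channels-last output vs an NCHW reference."""
    ref = nchw_to_nhwc(np.asarray(reference_nchw, dtype=np.float32))
    cand = np.asarray(candidate_nhwc, dtype=np.float32)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    diff = np.abs(ref - cand)
    return float(diff.max()), float(diff.mean())


# Usage (mlx_alpha_final: hypothetical NHWC output already converted to NumPy):
# golden = np.load("reference/fixtures/golden.npz")
# max_err, mean_err = abs_errors(golden["alpha_final"], mlx_alpha_final)
```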
Performance
Compiled inference
Use compile=True for fused execution on fixed-resolution inputs:
```python
model = load_model("checkpoints/corridorkey_mlx.safetensors", img_size=512, compile=True)
```
The first call incurs a one-time compilation cost. Subsequent calls at the same
resolution run faster. Shapeless compilation (shapeless=True) is not recommended
due to shape-dependent reshapes in the Hiera backbone.
Benchmarking
```bash
uv run python scripts/bench_mlx.py
uv run python scripts/bench_mlx.py --resolutions 256 512 1024 --bench-runs 20
```
Reports eager vs. compiled latency, warmup cost, and a parity check for each resolution.
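The general shape of such a measurement is sketched below: time the warmup calls separately (for a compiled model the first call includes compilation), then report statistics over repeated runs. MLX evaluates lazily, so the benchmarked callable should force evaluation, for example with mx.eval on the outputs; otherwise only graph construction is timed. bench is a hypothetical helper, not the script's code, and model and x in the usage comment are placeholders.

```python
import statistics
import time
from typing import Callable


def bench(fn: Callable[[], object], warmup: int = 1, runs: int = 20) -> dict[str, float]:
    """Time warmup calls separately, then collect per-call latencies."""
    start = time.perf_counter()
    for _ in range(warmup):
        fn()  # a compiled model pays its one-time compilation cost here
    warmup_s = time.perf_counter() - start

    latencies = []
    for _ in range(runs):
        t0 = time.perf_counter()
        fn()
        latencies.append(time.perf_counter() - t0)
    return {
        "warmup_s": warmup_s,
        "median_s": statistics.median(latencies),
        "min_s": min(latencies),
    }


# With MLX, force evaluation inside the timed callable, e.g.:
# bench(lambda: mx.eval(model(x)))
```

Reporting the median rather than the mean keeps a single slow run (for example from thermal throttling) from skewing the result.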
Large images (tiled inference)
For images larger than the model's input resolution, use tiled inference with overlap blending:
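```python
from corridorkey_mlx.inference.pipeline import load_model
from corridorkey_mlx.inference.tiling import tiled_inference

model = load_model("checkpoints/corridorkey_mlx.safetensors", img_size=512)
# preprocess is not shown here: it stands for whatever builds the
# full-resolution (1, H, W, 4) RGB + alpha-hint input
x = preprocess(rgb, alpha_hint)
result = tiled_inference(model, x, tile_size=512, overlap=64)
```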
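Overlap blending can be illustrated independently of the model. The sketch below assumes a channels-last (H, W, C) NumPy array and a per-tile callable. It places tiles on a stride of tile - overlap (the last tile sits flush against the border), weights each tile with a linear ramp that fades across the overlap region, and normalizes by the summed weights. It is a generic sketch of the technique with hypothetical names; the repo's tiled_inference may use a different weighting.

```python
from typing import Callable

import numpy as np


def tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    """Tile origins on a stride of tile - overlap; the last tile ends at the border."""
    if not 0 <= overlap < tile:
        raise ValueError("overlap must be in [0, tile)")
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, tile - overlap))
    if starts[-1] != length - tile:
        starts.append(length - tile)
    return starts


def edge_ramp(n: int, overlap: int) -> np.ndarray:
    """1D weights rising linearly over `overlap` pixels at each end (always > 0)."""
    ramp = np.ones(n, dtype=np.float64)
    o = min(overlap, n // 2)
    if o > 0:
        edge = np.arange(1, o + 1, dtype=np.float64) / (o + 1)
        ramp[:o] = edge
        ramp[n - o:] = edge[::-1]
    return ramp


def tiled_apply(
    image: np.ndarray,                       # (H, W, C) channels-last input
    fn: Callable[[np.ndarray], np.ndarray],  # (h, w, C) -> (h, w, C_out)
    tile: int = 512,
    overlap: int = 64,
) -> np.ndarray:
    h, w = image.shape[:2]
    out = None
    weight_sum = np.zeros((h, w, 1), dtype=np.float64)
    for y in tile_starts(h, tile, overlap):
        for x in tile_starts(w, tile, overlap):
            patch = image[y:y + tile, x:x + tile]
            pred = np.asarray(fn(patch), dtype=np.float64)
            ph, pw = patch.shape[:2]
            wt = np.outer(edge_ramp(ph, overlap), edge_ramp(pw, overlap))[..., None]
            if out is None:
                out = np.zeros((h, w, pred.shape[-1]), dtype=np.float64)
            out[y:y + ph, x:x + pw] += pred * wt
            weight_sum[y:y + ph, x:x + pw] += wt
    return out / weight_sum  # every pixel is covered by at least one tile
```

Because the weights are strictly positive and normalized per pixel, a per-pixel function passes through unchanged, while model outputs that disagree near tile edges are cross-faded instead of producing hard seams.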
Recommended settings for Apple Silicon
| Setting | Value | Notes |
|---|---|---|
| img_size | 512 | Good speed/quality balance |
| compile | True | ~1.5–2x faster after warmup |
| tile_size | 512 | Match img_size for tiling |
| overlap | 64 | Smooth blending at tile boundaries |
Comparing against PyTorch reference
```bash
uv run python scripts/compare_reference.py
```
Using as a CorridorKey backend
This repo can be consumed as a drop-in MLX backend by the main CorridorKey app.
Install (editable, from sibling checkout)
```bash
# from the main CorridorKey repo directory
uv pip install -e ../corridorkey-mlx
```
Engine API
```python
from corridorkey_mlx import CorridorKeyMLXEngine

engine = CorridorKeyMLXEngine(
    checkpoint_path="/abs/path/to/corridorkey_mlx.safetensors",
    img_size=2048,  # production (512 for dev)
    use_refiner=True,
    compile=True,  # faster after first call
)
# rgb_uint8: (H, W, 3) uint8 RGB frame; mask_uint8: (H, W) uint8 alpha hint
result = engine.process_frame(rgb_uint8, mask_uint8)
# result["alpha"]: (H, W) uint8 alpha matte
# result["fg"]: (H, W, 3) uint8 foreground
# result["comp"]: (H, W, 3) uint8 fg composited over black
# result["processed"]: (H, W, 3) uint8 (placeholder, same as fg)
```
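Compositing over black amounts to multiplying the foreground by alpha, since the black background contributes nothing. A minimal sketch, assuming fg is straight (not premultiplied) color and that the composite is computed from the uint8 outputs; the engine may instead composite in floating point before quantizing. composite_over_black is a hypothetical name.

```python
import numpy as np


def composite_over_black(fg: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """(H, W, 3) uint8 fg and (H, W) uint8 alpha -> (H, W, 3) uint8 fg * alpha."""
    a = alpha.astype(np.float32)[..., None] / 255.0  # broadcast alpha over RGB
    comp = fg.astype(np.float32) * a  # black background adds 0 * (1 - a)
    return np.clip(np.rint(comp), 0, 255).astype(np.uint8)
```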
Expected inputs
- image: numpy uint8 (H, W, 3) RGB in standard sRGB color space.
- mask: numpy uint8 (H, W) or (H, W, 1) grayscale alpha hint.
- checkpoint: .safetensors format, converted from PyTorch via scripts/convert_weights.py.
Inputs are resized internally to img_size for inference, then outputs are
resized back to the original input resolution.
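A sketch of how such inputs might be validated and packed into a single 4-channel array. It assumes the mask must match the image's height and width, that both are scaled to [0, 1], and a channels-last (1, H, W, 4) layout like the tiled-inference example; resizing to img_size is left out. pack_inputs is a hypothetical name.

```python
import numpy as np


def pack_inputs(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Validate engine-style inputs and stack them into a (1, H, W, 4) float32 array."""
    if image.dtype != np.uint8 or image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"image must be uint8 (H, W, 3), got {image.dtype} {image.shape}")
    if mask.ndim == 3 and mask.shape[2] == 1:
        mask = mask[..., 0]  # accept (H, W, 1) as well as (H, W)
    if mask.dtype != np.uint8 or mask.shape != image.shape[:2]:
        raise ValueError(f"mask must be uint8 {image.shape[:2]}, got {mask.dtype} {mask.shape}")
    rgb = image.astype(np.float32) / 255.0  # sRGB values, no linearization
    hint = mask.astype(np.float32)[..., None] / 255.0
    return np.concatenate([rgb, hint], axis=-1)[None]
```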
Smoke test
```bash
uv run python scripts/smoke_engine.py \
  --image input.png --hint hint.png \
  --checkpoint checkpoints/corridorkey_mlx.safetensors \
  --img-size 512
```
2048 smoke test
Validates full end-to-end inference at CorridorKey's native 2048 resolution.
Uses samples/sample.png + samples/hint.png by default; falls back to
synthetic inputs if samples are unavailable.
```bash
uv run python scripts/smoke_2048.py
```
With real images:
```bash
uv run python scripts/smoke_2048.py --image shot.png --hint hint.png
```
Reports timing, peak memory, output shapes, and value-range diagnostics. This is an execution check, not a 2048 parity validation.
To run the slow pytest version:
```bash
uv run pytest -m slow
```
Standalone scripts vs engine usage
| | Standalone (scripts/infer.py) | Engine (CorridorKeyMLXEngine) |
|---|---|---|
| Input | file paths | numpy arrays |
| Output | saved PNGs | in-memory dict |
| Returns | alpha, foreground | alpha, fg, comp, processed |
| Default img_size | 512 | 2048 |
| Use case | one-off CLI inference | app backend integration |
Stubs (not yet implemented)
- despill_strength: accepted but ignored (warns once)
- auto_despeckle / despeckle_size: accepted but ignored (warns once)
- input_is_linear: accepted but a no-op (the model expects sRGB)
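The warn-once behavior can be implemented with a module-level set of option names that have already triggered a warning. A minimal sketch: STUB_OPTIONS and warn_ignored_options are hypothetical names, only the options the list says warn are included, and the sketch warns whenever one is passed with a non-None value, which may differ from the engine's actual condition.

```python
from __future__ import annotations

import warnings

# Options accepted but not implemented yet that emit a one-time warning
STUB_OPTIONS = ("despill_strength", "auto_despeckle", "despeckle_size")
_already_warned: set[str] = set()


def warn_ignored_options(**options: object) -> None:
    """Emit one UserWarning per stub option, the first time it is passed."""
    for name in STUB_OPTIONS:
        if options.get(name) is None or name in _already_warned:
            continue
        _already_warned.add(name)
        warnings.warn(f"{name} is not implemented yet and will be ignored", UserWarning, stacklevel=2)
```

Tracking names explicitly, rather than relying on the warnings module's default once-per-location filter, keeps the behavior stable even when callers configure warnings to "always".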
Python version
Requires Python >=3.11. Compatible with the main CorridorKey repo's 3.11 target.
Current Status
Phases 1–6 complete. Full model assembly with end-to-end parity verified. Optimization, benchmarking, and tiled inference available. Engine integration surface available for backend consumption.