Architecture¶
Core Components¶
┌─────────────────────────────────────────────────────────────┐
│ CLI (main.rs) / Library API (lib.rs) / Python API (PyO3) │
└──────────────┬──────────────────────────────────────────────┘
│
┌──────────┴──────────┬──────────────┬─────────────────┐
▼ ▼ ▼ ▼
┌────────┐ ┌──────────────┐ ┌──────────┐ ┌──────────────┐
│Loader │────▶│ Validator │──▶│ Context │───▶│ Backend │
│(JSON) │ │(graph.rs) │ │(selects) │ │ Selection │
└────────┘ └──────────────┘ └────┬─────┘ └──────┬───────┘
│ │
▼ ▼
┌──────────┐ ┌──────────────┐
│ Builder │ │ Converter │
│(backend- │ │ (Runtime) │
│agnostic) │ │ │
└────┬─────┘ └──────┬───────┘
│ │
▼ ▼
┌─────────────┐ ┌────────────────┐
│ MLGraph │ │ ONNX / CoreML │
│(immutable) │ │ Execution │
└─────────────┘ └────────────────┘
Key Principles¶
1. Backend-Agnostic Graph Representation¶
- builder.build() creates an immutable, platform-independent GraphInfo structure
- Contains operands, operations, inputs, outputs, and constant data
- No backend-specific artifacts at this stage
2. Runtime Backend Selection (WebNN Spec-Compliant)¶
Following the W3C WebNN Device Selection Explainer:
- Backend selection happens at context creation via accelerated and power_preference hints:
  - accelerated=False → ONNX Runtime CPU
  - accelerated=True + power="high-performance" → GPU preferred (ONNX or CoreML)
  - accelerated=True + power="low-power" → NPU preferred (CoreML Neural Engine on Apple Silicon)
- Platform autonomously selects the actual device based on availability and runtime conditions
- Selection logic in PyMLContext::select_backend(), sketched below
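A minimal sketch of this mapping (hypothetical Backend enum and simplified signature; the real PyMLContext::select_backend() also consults feature flags and runtime availability):

// Hypothetical sketch of the hint-to-backend mapping described above.
#[derive(Debug)]
enum Backend {
    OnnxCpu,
    OnnxGpu,
    CoreMl,
}

fn select_backend(accelerated: bool, power: &str, is_macos: bool) -> Backend {
    match (accelerated, power) {
        // No acceleration requested: portable ONNX Runtime CPU path.
        (false, _) => Backend::OnnxCpu,
        // Low-power hint: prefer the Neural Engine via CoreML on Apple Silicon.
        (true, "low-power") if is_macos => Backend::CoreMl,
        // High-performance (or unspecified) hint: prefer a GPU-capable backend.
        (true, _) if is_macos => Backend::CoreMl,
        (true, _) => Backend::OnnxGpu,
    }
}

fn main() {
    println!("{:?}", select_backend(false, "default", cfg!(target_os = "macos")));
    println!("{:?}", select_backend(true, "low-power", cfg!(target_os = "macos")));
}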
3. MLTensor Management¶
Following the W3C WebNN MLTensor Explainer:
- Explicit tensor management with descriptor flags (readable, writable, exportableToGPU)
- destroy() method for explicit resource cleanup
- dispatch() for async execution with MLTensor inputs/outputs
- Permission enforcement on read/write operations
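The flag-gating and explicit cleanup can be illustrated with a hedged sketch (hypothetical field names; the real MLTensor wraps backend buffers, not a plain Vec):

// Illustrative only: descriptor flags gate access, destroy() frees eagerly.
struct MLTensorDescriptor {
    readable: bool,
    writable: bool,
    exportable_to_gpu: bool,
}

struct MLTensor {
    desc: MLTensorDescriptor,
    data: Option<Vec<f32>>, // None once destroy() has run
}

impl MLTensor {
    fn read(&self) -> Result<&[f32], String> {
        // Permission enforcement: reads require the readable flag.
        if !self.desc.readable {
            return Err("tensor was not created with the readable flag".into());
        }
        self.data.as_deref().ok_or_else(|| "tensor was destroyed".into())
    }

    fn destroy(&mut self) {
        // Explicit cleanup: drop the backing buffer immediately.
        self.data = None;
    }
}

fn main() {
    let mut t = MLTensor {
        desc: MLTensorDescriptor { readable: true, writable: false, exportable_to_gpu: false },
        data: Some(vec![1.0, 2.0, 3.0]),
    };
    println!("{:?}", t.read()); // Ok([1.0, 2.0, 3.0])
    t.destroy();
    println!("{:?}", t.read()); // Err("tensor was destroyed")
}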
4. Lazy Backend Conversion¶
- Backend-specific conversion happens during compute(), not build()
- compute() routes to the appropriate backend method:
  - compute_onnx() for ONNX Runtime
  - compute_coreml() for CoreML
  - compute_fallback() when no backend is available
- Same graph can be executed on different backends via different contexts
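A rough sketch of the routing (hypothetical Context and Backend types; the real methods are the ones named above):

// Backend-specific work is deferred to compute(); build() produced only
// the backend-agnostic graph.
enum Backend {
    Onnx,
    CoreMl,
    None,
}

struct Context {
    backend: Backend,
}

impl Context {
    fn compute(&self, graph: &str) -> String {
        match self.backend {
            Backend::Onnx => format!("compute_onnx({graph})"),
            Backend::CoreMl => format!("compute_coreml({graph})"),
            Backend::None => format!("compute_fallback({graph})"),
        }
    }
}

fn main() {
    // The same graph dispatches through whichever context it is given.
    let cpu = Context { backend: Backend::Onnx };
    let ane = Context { backend: Backend::CoreMl };
    println!("{}", cpu.compute("my_graph"));
    println!("{}", ane.compute("my_graph"));
}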
5. Rust-First Architecture¶
- All core functionality in pure Rust (validation, conversion, execution)
- Python bindings are thin wrappers exposing Rust functionality
- Rust library usable independently without Python
- Design principle: "Rust is the implementation, Python is the interface"
Shape Inference¶
Shape inference is the process of automatically computing output tensor shapes of neural network operations based on their input shapes and operation parameters, without executing the operation.
Why Shape Inference Matters¶
Shape inference enables:
- Early validation - Catch shape mismatches at build time, not runtime
- Memory allocation - Backend runtimes know output buffer sizes before execution
- Graph optimization - Enables static analysis and optimization passes
- Self-describing graphs - Graphs are fully annotated and backend-agnostic
How It Works¶
Each WebNN operation has a shape inference function in src/shape_inference.rs that computes output shapes. Shape inference happens during graph building, before any backend selection or execution.
Binary Operations (add, mul, div, etc.):
- Use NumPy-style broadcasting rules
- Two dimensions are compatible if equal or one is 1
- Output dimension is the maximum of the two
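A minimal, self-contained sketch of these broadcasting rules (not the actual src/shape_inference.rs code):

// Right-align the two shapes; missing leading dimensions act as 1.
fn broadcast_shapes(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    let rank = a.len().max(b.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        let da = if i < rank - a.len() { 1 } else { a[i - (rank - a.len())] };
        let db = if i < rank - b.len() { 1 } else { b[i - (rank - b.len())] };
        // Compatible if equal or one side is 1; output is the maximum.
        if da != db && da != 1 && db != 1 {
            return Err(format!("incompatible dimensions {} and {}", da, db));
        }
        out.push(da.max(db));
    }
    Ok(out)
}

fn main() {
    assert_eq!(broadcast_shapes(&[2, 3], &[3]).unwrap(), vec![2, 3]);
    assert_eq!(broadcast_shapes(&[4, 1, 3], &[2, 1]).unwrap(), vec![4, 2, 3]);
    assert!(broadcast_shapes(&[2, 3], &[4]).is_err());
}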
Matrix Multiplication:
// Simple 2D: [M, K] @ [K, N] → [M, N]
infer_matmul_shape([2, 3], [3, 4]) → [2, 4]
// Batched: [batch, M, K] @ [batch, K, N] → [batch, M, N]
infer_matmul_shape([5, 2, 3], [5, 3, 4]) → [5, 2, 4]
// Validates inner dimensions match (K must equal)
infer_matmul_shape([2, 3], [4, 5]) → Error: 3 != 4
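A runnable sketch covering just the two cases shown above (the real infer_matmul_shape in src/shape_inference.rs handles more):

// 2-D and uniformly batched matmul shape inference.
fn infer_matmul_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    match (a, b) {
        // [M, K] @ [K, N] -> [M, N]
        ([m, k1], [k2, n]) => {
            if k1 != k2 {
                return Err(format!("inner dimensions must match: {} != {}", k1, k2));
            }
            Ok(vec![*m, *n])
        }
        // [batch, M, K] @ [batch, K, N] -> [batch, M, N]
        ([b1, m, k1], [b2, k2, n]) => {
            if b1 != b2 {
                return Err(format!("batch dimensions must match: {} != {}", b1, b2));
            }
            if k1 != k2 {
                return Err(format!("inner dimensions must match: {} != {}", k1, k2));
            }
            Ok(vec![*b1, *m, *n])
        }
        _ => Err("this sketch only handles 2-D and 3-D inputs".into()),
    }
}

fn main() {
    assert_eq!(infer_matmul_shape(&[2, 3], &[3, 4]).unwrap(), vec![2, 4]);
    assert_eq!(infer_matmul_shape(&[5, 2, 3], &[5, 3, 4]).unwrap(), vec![5, 2, 4]);
    assert!(infer_matmul_shape(&[2, 3], &[4, 5]).is_err());
}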
Convolution (conv2d):
- Takes input shape, filter shape, strides, padding, dilations
- Computes spatial output dimensions:
output_h = floor((input_h + pad_top + pad_bottom - dilation_h * (kernel_h - 1) - 1) / stride_h + 1)
output_w = floor((input_w + pad_left + pad_right - dilation_w * (kernel_w - 1) - 1) / stride_w + 1)
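Transcribing the formula directly (integer division plays the role of floor), the familiar MobileNet-style first convolution maps a 224×224 input to 112×112:

// Output size for one spatial dimension, per the formula above.
fn conv_out_dim(input: usize, pad_begin: usize, pad_end: usize,
                kernel: usize, stride: usize, dilation: usize) -> usize {
    (input + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // 3x3 kernel, stride 2, padding 1, dilation 1: 224 -> 112.
    assert_eq!(conv_out_dim(224, 1, 1, 3, 2, 1), 112);
    // Same kernel and padding at stride 1 preserves the size: 224 -> 224.
    assert_eq!(conv_out_dim(224, 1, 1, 3, 1, 1), 224);
}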
Reshape:
// Validates element count is preserved
validate_reshape([2, 3, 4], [6, 4]) → OK (24 elements in both)
validate_reshape([2, 3, 4], [5, 5]) → Error (24 != 25 elements)
Pooling Operations:
- Similar to convolution but without filters
- Computes output spatial dimensions based on window size, strides, padding
- Handles both average and max pooling
- Global pooling reduces spatial dimensions to 1x1
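A sketch of the pooling arithmetic, assuming symmetric padding; global pooling is the special case where the window spans the whole dimension:

// Output size for one spatial dimension of a pooling window.
fn pool_out_dim(input: usize, pad: usize, window: usize, stride: usize) -> usize {
    (input + 2 * pad - window) / stride + 1
}

fn main() {
    // 2x2 window, stride 2, no padding: halves each spatial dimension.
    assert_eq!(pool_out_dim(112, 0, 2, 2), 56);
    // Global pooling: the window covers the whole dimension -> 1.
    assert_eq!(pool_out_dim(7, 0, 7, 1), 1);
}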
Integration with Graph Builder¶
Shape inference is called automatically during graph construction:
# Python API example
x = builder.input("x", [2, 3], "float32") # Shape: [2, 3]
y = builder.input("y", [3, 4], "float32") # Shape: [3, 4]
z = builder.matmul(x, y) # Shape: [2, 4] (inferred)
output = builder.relu(z) # Shape: [2, 4] (preserved)
When you call builder.matmul(x, y), the implementation:
1. Calls infer_matmul_shape([2, 3], [3, 4]) from src/shape_inference.rs
2. Gets result [2, 4]
3. Creates operand descriptor with inferred shape
4. Stores operation in graph with validated inputs/outputs
This creates a fully-annotated, backend-agnostic graph that can be:
- Validated for correctness
- Visualized with Graphviz
- Converted to ONNX, CoreML, or other formats
- Executed on different backends without re-inference
Implementation Status¶
All 85 WebNN operations have shape inference implemented (100% coverage). Each operation includes:
- Shape inference function in src/shape_inference.rs
- Comprehensive validation (dimension compatibility, parameter constraints)
- Unit tests covering typical cases and edge cases
- Error messages with context for debugging
File Organization¶
src/
├── lib.rs # Public Rust API exports
├── main.rs # CLI entry point
├── graph.rs # Core data structures (backend-agnostic)
├── error.rs # Error types
├── validator.rs # Graph validation
├── loader.rs # JSON loading
├── graphviz.rs # DOT export
├── protos.rs # Protobuf module setup
├── converters/
│ ├── mod.rs # Registry and trait
│ ├── onnx.rs # ONNX converter
│ └── coreml.rs # CoreML converter
├── executors/
│ ├── mod.rs # Conditional compilation
│ ├── onnx.rs # ONNX runtime
│ └── coreml.rs # CoreML runtime
└── python/ # Python bindings (PyO3)
├── mod.rs # Python module definition
├── context.rs # ML and MLContext classes (backend selection)
├── graph_builder.rs # MLGraphBuilder class
├── graph.rs # MLGraph class
├── operand.rs # MLOperand class
└── tensor.rs # MLTensor class
python/webnn/ # Python package
├── __init__.py # Package exports (AsyncMLContext)
└── __init__.pyi # Type stubs
tests/
├── test_python_api.py # Python API tests (320+ tests)
├── test_wpt_conformance.py # WPT spec compliance tests
└── test_integration.py # Integration tests
examples/
├── python_simple.py # Basic Python example
├── python_matmul.py # Matrix multiplication
├── mobilenetv2_complete.py # Complete pretrained MobileNetV2
├── text_generation_gpt.py # Transformer with attention
└── train_text_model.py # Model training script
Design Patterns¶
Registry Pattern (Converters)¶
- ConverterRegistry manages converters dynamically
- Trait objects: Box<dyn GraphConverter + Send + Sync>
- Extensible without modifying core code
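A condensed sketch of the pattern (stand-in Graph type and simplified convert signature; the real trait and registry live in src/converters/mod.rs):

use std::collections::HashMap;

struct Graph; // Stand-in for the backend-agnostic graph type.

trait GraphConverter: Send + Sync {
    fn convert(&self, graph: &Graph) -> Vec<u8>;
}

#[derive(Default)]
struct ConverterRegistry {
    converters: HashMap<String, Box<dyn GraphConverter + Send + Sync>>,
}

impl ConverterRegistry {
    // New backends plug in here without touching core code.
    fn register(&mut self, name: &str, c: Box<dyn GraphConverter + Send + Sync>) {
        self.converters.insert(name.to_string(), c);
    }

    fn convert(&self, name: &str, graph: &Graph) -> Option<Vec<u8>> {
        self.converters.get(name).map(|c| c.convert(graph))
    }
}

struct OnnxConverter;

impl GraphConverter for OnnxConverter {
    fn convert(&self, _graph: &Graph) -> Vec<u8> {
        Vec::new() // A real converter would emit ONNX protobuf bytes.
    }
}

fn main() {
    let mut registry = ConverterRegistry::default();
    registry.register("onnx", Box::new(OnnxConverter));
    let bytes = registry.convert("onnx", &Graph).unwrap();
    println!("converted {} bytes", bytes.len());
}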
Builder Pattern (Graph Construction)¶
- MLGraphBuilder provides a fluent API for graph construction
- Incremental construction of complex structures
- Used in ONNX and CoreML converters
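A toy illustration of the fluent, incremental style (hypothetical types; the real MLGraphBuilder tracks operands, shapes, and dtypes):

// Each call records one step and returns the builder for chaining.
#[derive(Default)]
struct GraphBuilder {
    ops: Vec<String>,
}

impl GraphBuilder {
    fn input(mut self, name: &str) -> Self {
        self.ops.push(format!("input {name}"));
        self
    }
    fn relu(mut self) -> Self {
        self.ops.push("relu".into());
        self
    }
    fn build(self) -> Vec<String> {
        self.ops // Immutable snapshot of the recorded operations.
    }
}

fn main() {
    let graph = GraphBuilder::default().input("x").relu().build();
    println!("{:?}", graph); // ["input x", "relu"]
}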
Validation Pipeline¶
- Immutable graph input
- Stateful validator with progressive checks
- Comprehensive artifacts returned for downstream use
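A hedged sketch of that shape, with hypothetical types: an immutable graph goes in, a stateful validator accumulates findings across passes, and the result comes out at the end:

struct Graph {
    inputs: Vec<String>,
    outputs: Vec<String>,
}

#[derive(Default)]
struct Validator {
    errors: Vec<String>, // Accumulated across progressive checks.
}

impl Validator {
    fn check_io(&mut self, g: &Graph) {
        if g.inputs.is_empty() {
            self.errors.push("graph has no inputs".into());
        }
        if g.outputs.is_empty() {
            self.errors.push("graph has no outputs".into());
        }
    }

    fn finish(self) -> Result<(), Vec<String>> {
        if self.errors.is_empty() { Ok(()) } else { Err(self.errors) }
    }
}

fn main() {
    let g = Graph { inputs: vec!["x".into()], outputs: vec![] };
    let mut v = Validator::default();
    v.check_io(&g); // Further passes would run here.
    println!("{:?}", v.finish()); // Err(["graph has no outputs"])
}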
Conditional Compilation¶
- #[cfg(target_os = "macos")] for platform-specific code
- #[cfg(feature = "...")] for optional features
- Graceful degradation on unsupported platforms
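For example (illustrative functions, not the project's actual module layout):

// Platform-specific code compiles only on macOS; elsewhere the fallback
// reports CoreML as unavailable rather than failing to build.
#[cfg(target_os = "macos")]
fn coreml_available() -> bool {
    true
}

#[cfg(not(target_os = "macos"))]
fn coreml_available() -> bool {
    false
}

fn main() {
    println!("CoreML available: {}", coreml_available());
}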
Technical Decisions¶
- WebNN Spec Compliance: Follows W3C WebNN Device Selection and MLTensor explainers
- Protobuf for Interop: Native format for ONNX and CoreML
- Compile-time Codegen: Protobufs compiled at build time
- Feature Flags: Optional runtimes to minimize dependencies
- Objective-C FFI: Direct CoreML access on macOS
- Zero-copy where possible: Bytes type for efficiency
- Registry Pattern: Pluggable converters without core changes
Platform Support¶
- Validation & Conversion: Cross-platform (Linux, macOS, Windows)
- ONNX Execution: Cross-platform with the onnx-runtime feature (CPU/GPU)
- CoreML Execution: macOS only with the coreml-runtime feature (GPU/Neural Engine)
- Neural Engine: macOS with Apple Silicon (via CoreML)
- Python Bindings: Cross-platform with the python feature (Python 3.11+)
Implementation Status¶
85 WebNN operations fully implemented across all backends:
- Shape Inference: 85/85 (100%)
- Python API: 85/85 (100%)
- ONNX Backend: 85/85 (100%)
- CoreML MLProgram: 85/85 (100%)
See implementation-status.md for complete details.