---
license: apache-2.0
library_name: onnxruntime
tags:
- sleep-staging
- knowledge-distillation
- onnx
- torchscript
- edge-deployment
- eeg
- polysomnography
datasets:
- sleep-edf
metrics:
- accuracy
pipeline_tag: tabular-classification
model-index:
- name: Conv1dStack_T2_a0.3
  results:
  - task:
      type: tabular-classification
      name: Sleep Stage Classification
    dataset:
      type: sleep-edf
      name: Sleep-EDF
    metrics:
    - type: accuracy
      value: 0.752
      name: Validation Accuracy
---

# Conv1dStack_T2_a0.3: Distilled Sleep Stage Classifier

A tiny (103 KB, 25,957 params) sleep stage classifier distilled from [SleepFM](https://arxiv.org/abs/2311.07919) for real-time edge deployment on NVIDIA Jetson TK1 and similar constrained devices.

## Model Details

| Property | Value |
|----------|-------|
| Architecture | Conv1dStack |
| Parameters | 25,957 |
| Model size | 103.3 KB |
| Distillation temperature | 2 |
| Alpha (hard label weight) | 0.3 |
| Validation accuracy | 75.2% |
| Input shape | `(B, S, 128)`, pre-pooled embeddings |
| Output | 5-class logits (Wake, REM, N1, N2, N3) |
| ONNX opset | 11 |

### Student Architecture Config

```yaml
Conv1dStack:
  hidden_channels: 32
  kernel_size: 5
```

### Distillation Setup

- **Teacher**: SleepFM (SleepEventLSTMClassifier), biLSTM, 128-dim embeddings
- **Sweep**: 3 temperatures × 3 alphas × 3 architectures = 27 experiments
- **Training**: 50 epochs, AdamW lr=0.001, early stopping (patience=10)
- **Data**: Sleep-EDF (5 train / 1 val / 1 test subjects)

### Target Hardware

| Spec | Value |
|------|-------|
| Device | NVIDIA Jetson TK1 |
| CUDA cores | 192 |
| RAM | 2 GB |
| Compute capability | 3.2 |

## Usage

### ONNX Runtime

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("Conv1dStack_T2_a0.3.onnx")

# Input: pre-pooled embeddings (batch, seq_len, 128)
embeddings = np.random.randn(1, 120, 128).astype(np.float32)
logits = session.run(None, {"input": embeddings})[0]

# Pick the highest-scoring class along the last (class) axis
predicted_stages = np.argmax(logits, axis=-1)

# Stage mapping: 0=Wake, 1=REM, 2=N1, 3=N2, 4=N3
print(predicted_stages)
```

### TorchScript

```python
import torch

model = torch.jit.load("Conv1dStack_T2_a0.3.pt")
model.eval()  # inference mode for any dropout / normalization layers

# Input: pre-pooled embeddings (batch, seq_len, 128)
embeddings = torch.randn(1, 120, 128)
with torch.no_grad():  # no gradients needed at inference time
    logits = model(embeddings)
predicted_stages = logits.argmax(dim=-1)
```

## Files

| File | Format | Description |
|------|--------|-------------|
| `Conv1dStack_T2_a0.3.onnx` | ONNX (opset 11) | For ONNX Runtime / TensorRT |
| `Conv1dStack_T2_a0.3.pt` | TorchScript | For PyTorch / LibTorch on-device |

## Limitations

- Trained on Sleep-EDF only (7 subjects); may not generalize to other PSG datasets
- Expects pre-pooled 128-dim embeddings from SleepFM's encoder, not raw EEG
- No per-class metrics reported (overall accuracy only)
- Distilled from a single teacher checkpoint

## Citation

```bibtex
@misc{circadia-distill-2026,
  title={Distilled Sleep Stage Classifier for Edge Deployment},
  year={2026},
  url={https://github.com/circadia}
}
```
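
## Distillation Loss Sketch

This card lists a distillation temperature of 2 and an alpha (hard label weight) of 0.3 but does not state the exact training objective. The block below is a minimal sketch of the common soft-target formulation, assuming the loss is `alpha * CE(student, label) + (1 - alpha) * T^2 * KL(teacher_T || student_T)`. The `T^2` scaling, the function names, and the example logits are assumptions for illustration, not taken from the training code.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_z for x in scaled]


def distillation_loss(student_logits, teacher_logits, hard_label,
                      temperature=2.0, alpha=0.3):
    """Assumed objective: alpha weights the hard-label term, (1 - alpha) the soft term."""
    # Hard-label term: cross-entropy against the ground-truth stage at T = 1.
    hard = -log_softmax(student_logits)[hard_label]

    # Soft-label term: KL(teacher || student) on temperature-softened distributions.
    log_t = log_softmax(teacher_logits, temperature)
    log_s = log_softmax(student_logits, temperature)
    soft = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_t, log_s))

    # T^2 keeps the soft-term gradient scale comparable across temperatures (assumed).
    return alpha * hard + (1.0 - alpha) * temperature ** 2 * soft


# Hypothetical logits for one 5-class epoch (Wake, REM, N1, N2, N3); label 3 = N2.
student = [0.2, -0.5, 0.1, 1.4, 0.3]
teacher = [0.1, -1.0, 0.0, 2.5, 0.6]
print(distillation_loss(student, teacher, hard_label=3))
```

With `alpha=0.3`, 70% of the weight falls on matching the teacher's softened distribution and 30% on the ground-truth stage label. When student and teacher logits are identical, the soft term vanishes and only the weighted cross-entropy remains.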