license: apache-2.0
library_name: onnxruntime
tags:
- sleep-staging
- knowledge-distillation
- onnx
- torchscript
- edge-deployment
- eeg
- polysomnography
datasets:
- sleep-edf
metrics:
- accuracy
pipeline_tag: tabular-classification
model-index:
- name: Conv1dStack_T2_a0.3
  results:
  - task:
      type: tabular-classification
      name: Sleep Stage Classification
    dataset:
      type: sleep-edf
      name: Sleep-EDF
    metrics:
    - type: accuracy
      value: 0.752
      name: Validation Accuracy
Conv1dStack_T2_a0.3 — Distilled Sleep Stage Classifier
A tiny (103.3 KB, 25,957-parameter) sleep stage classifier distilled from
SleepFM for real-time edge deployment on the
NVIDIA Jetson TK1 and similarly constrained devices.
Model Details
| Property | Value |
|---|---|
| Architecture | Conv1dStack |
| Parameters | 25,957 |
| Model size | 103.3 KB |
| Distillation temperature | 2 |
| Alpha (hard label weight) | 0.3 |
| Validation accuracy | 75.2% |
| Input shape | (B, S, 128), pre-pooled embeddings |
| Output | 5-class logits (Wake, REM, N1, N2, N3) |
| ONNX opset | 11 |
Student Architecture Config
```yaml
Conv1dStack:
  hidden_channels: 32
  kernel_size: 5
```
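The config above does not spell out the layer stack, so the following is only a sketch of one layout that happens to reproduce the stated parameter count. The layer names, the presence of batch normalization, and the kernel-size-1 classification head are assumptions made for illustration, not the confirmed architecture.

```python
def conv1d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a 1-D convolution."""
    return in_channels * out_channels * kernel_size + out_channels


def batchnorm_params(channels):
    """Learnable scale and shift of a batch-norm layer."""
    return 2 * channels


# Values taken from the model card: 128-dim inputs, 32 hidden channels,
# kernel size 5, 5 output classes.
EMBED_DIM, HIDDEN, KERNEL, NUM_CLASSES = 128, 32, 5, 5

# Hypothetical layout (assumption): two conv + batch-norm blocks followed by
# a per-position 32 -> 5 classifier with kernel size 1.
layers = {
    "conv1": conv1d_params(EMBED_DIM, HIDDEN, KERNEL),
    "bn1": batchnorm_params(HIDDEN),
    "conv2": conv1d_params(HIDDEN, HIDDEN, KERNEL),
    "bn2": batchnorm_params(HIDDEN),
    "head": conv1d_params(HIDDEN, NUM_CLASSES, 1),
}

total_params = sum(layers.values())
float32_bytes = total_params * 4  # raw weight storage at 4 bytes per parameter

print(total_params, float32_bytes)
```

Under these assumptions the total comes to 25,957 parameters, matching the table, and the raw float32 weights occupy 103,828 bytes, which is in the same range as the listed model size.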
Distillation Setup
- Teacher: SleepFM (SleepEventLSTMClassifier), a biLSTM over 128-dim embeddings
- Sweep: 3 temperatures × 3 alphas × 3 architectures = 27 experiments
- Training: up to 50 epochs, AdamW (lr = 0.001), early stopping (patience = 10)
- Data: Sleep-EDF (5 train / 1 val / 1 test subjects)
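The card gives the temperature (2) and the hard-label weight alpha (0.3) but not the exact loss. As a minimal sketch, assuming the common formulation in which a hard-label cross-entropy is mixed with a temperature-softened KL term scaled by T², a per-sample loss could look like the following. The function names, the T² factor, and the KL direction are assumptions, not taken from the training code.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of a list of logits at a given temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]


def distillation_loss(student_logits, teacher_logits, hard_label,
                      temperature=2.0, alpha=0.3):
    """alpha * CE(student, hard label) + (1 - alpha) * T^2 * KL(teacher || student)."""
    # Hard-label term: ordinary cross-entropy at temperature 1.
    hard_loss = -log_softmax(student_logits)[hard_label]

    # Soft term: KL divergence between temperature-softened distributions.
    teacher_logp = log_softmax(teacher_logits, temperature)
    student_logp = log_softmax(student_logits, temperature)
    soft_loss = sum(math.exp(t) * (t - s)
                    for t, s in zip(teacher_logp, student_logp))

    return alpha * hard_loss + (1 - alpha) * temperature ** 2 * soft_loss
```

When student and teacher logits agree, the soft term vanishes and only the weighted hard-label cross-entropy remains.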
Target Hardware
| Spec | Value |
|---|---|
| Device | NVIDIA Jetson TK1 |
| CUDA cores | 192 |
| RAM | 2 GB |
| Compute capability | 3.2 |
Usage
ONNX Runtime
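```python
import numpy as np
import onnxruntime as ort

# Load the exported student model.
session = ort.InferenceSession("Conv1dStack_T2_a0.3.onnx")

# Random stand-in input with shape (B, S, 128) = (1, 120, 128).
embeddings = np.random.randn(1, 120, 128).astype(np.float32)

# Run inference; the first output holds the logits.
logits = session.run(None, {"input": embeddings})[0]
predicted_stages = np.argmax(logits, axis=-1)  # predicted class indices
print(predicted_stages)
```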
TorchScript
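```python
import torch

# Load the TorchScript export and put it in inference mode.
model = torch.jit.load("Conv1dStack_T2_a0.3.pt")
model.eval()

embeddings = torch.randn(1, 120, 128)  # stand-in input, shape (B, S, 128)
with torch.no_grad():
    logits = model(embeddings)
predicted_stages = logits.argmax(dim=-1)  # predicted class indices
```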
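Both snippets end with integer class indices. A small helper can turn logits into readable stage names. This sketch assumes the class index order follows the order in which the stages are listed in the Output row (Wake, REM, N1, N2, N3); that ordering is not confirmed elsewhere in the card, and `STAGE_NAMES` and `decode_stages` are hypothetical names.

```python
# Assumed index-to-stage mapping, taken from the order listed under "Output".
STAGE_NAMES = ["Wake", "REM", "N1", "N2", "N3"]


def decode_stages(logits_per_step):
    """Map a sequence of 5-class logit vectors to stage names via argmax."""
    names = []
    for logits in logits_per_step:
        best = max(range(len(logits)), key=lambda i: logits[i])
        names.append(STAGE_NAMES[best])
    return names
```

With the ONNX Runtime example, `decode_stages(logits[0].tolist())` would decode the first item in the batch.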
Files
| File | Format | Description |
|---|---|---|
| Conv1dStack_T2_a0.3.onnx | ONNX (opset 11) | For ONNX Runtime / TensorRT |
| Conv1dStack_T2_a0.3.pt | TorchScript | For PyTorch / LibTorch on-device |
Limitations
- Trained on Sleep-EDF only (7 subjects), so it may not generalize to other PSG datasets
- Expects pre-pooled 128-dim embeddings from SleepFM's encoder, not raw EEG
- No per-class metrics reported (overall accuracy only)
- Distilled from a single teacher checkpoint
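Since only overall accuracy is reported, users who have labelled data may want per-class figures of their own. The following is a minimal sketch of per-class recall from integer stage indices; `per_class_recall` is a hypothetical helper and is not part of the released files.

```python
from collections import Counter


def per_class_recall(true_stages, predicted_stages, num_classes=5):
    """Fraction of samples of each true class that were predicted correctly."""
    support = Counter(true_stages)
    hits = Counter(t for t, p in zip(true_stages, predicted_stages) if t == p)
    # Classes that never occur in true_stages are omitted rather than reported as 0.
    return {c: hits[c] / support[c] for c in range(num_classes) if support[c] > 0}
```

Overall accuracy can hide weak classes when stage frequencies are imbalanced, so a per-class breakdown like this is a useful complement.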
Citation
@misc{circadia-distill-2026,
title={Distilled Sleep Stage Classifier for Edge Deployment},
year={2026},
url={https://github.com/circadia}
}