mazesmazes committed
Commit 0e60196 · verified · 1 Parent(s): bbe4853

Update custom model files, README, and requirements

Files changed (3):
  1. .gitattributes +0 -1
  2. README.md +263 -78
  3. asr_pipeline.py +164 -54
.gitattributes CHANGED
@@ -1,4 +1,3 @@
  *.safetensors filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
  tokenizer_config.json -filter -diff -merge text
- tokenizer.json filter=lfs diff=lfs merge=lfs -text

README.md CHANGED
@@ -1,82 +1,267 @@
  ---
- library_name: transformers
  tags:
- - generated_from_trainer
- model-index:
- - name: tiny-audio
-   results: []
  ---
  
- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
- should probably proofread and complete it, then remove this comment. -->
-
- # tiny-audio
-
- This model is a fine-tuned version of [](https://huggingface.co/) on the None dataset.
- It achieves the following results on the evaluation set:
- - Loss: 0.4587
-
- ## Model description
-
- More information needed
-
- ## Intended uses & limitations
-
- More information needed
-
- ## Training and evaluation data
-
- More information needed
-
- ## Training procedure
-
- ### Training hyperparameters
-
- The following hyperparameters were used during training:
- - learning_rate: 0.001
- - train_batch_size: 14
- - eval_batch_size: 14
- - seed: 42
- - gradient_accumulation_steps: 4
- - total_train_batch_size: 56
- - optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
- - lr_scheduler_type: polynomial
- - lr_scheduler_warmup_steps: 500
- - num_epochs: 1
- - label_smoothing_factor: 0.1
-
- ### Training results
-
- | Training Loss | Epoch  | Step  | Validation Loss |
- |:-------------:|:------:|:-----:|:---------------:|
- | 2.1737        | 0.0418 | 1000  | 0.4878          |
- | 2.1091        | 0.0836 | 2000  | 0.4777          |
- | 2.0988        | 0.1254 | 3000  | 0.4728          |
- | 2.0590        | 0.1672 | 4000  | 0.4705          |
- | 2.0484        | 0.2090 | 5000  | 0.4689          |
- | 2.0637        | 0.2508 | 6000  | 0.4670          |
- | 2.0505        | 0.2926 | 7000  | 0.4659          |
- | 2.0550        | 0.3344 | 8000  | 0.4650          |
- | 2.0516        | 0.3762 | 9000  | 0.4641          |
- | 2.0530        | 0.4180 | 10000 | 0.4634          |
- | 2.0301        | 0.4598 | 11000 | 0.4628          |
- | 2.0608        | 0.5016 | 12000 | 0.4623          |
- | 2.0428        | 0.5434 | 13000 | 0.4621          |
- | 2.0248        | 0.5852 | 14000 | 0.4620          |
- | 2.0525        | 0.6270 | 15000 | 0.4612          |
- | 2.0281        | 0.6688 | 16000 | 0.4609          |
- | 2.0338        | 0.7106 | 17000 | 0.4600          |
- | 2.0492        | 0.7524 | 18000 | 0.4605          |
- | 2.0261        | 0.7942 | 19000 | 0.4598          |
- | 2.0084        | 0.8360 | 20000 | 0.4593          |
- | 2.0236        | 0.8778 | 21000 | 0.4590          |
- | 2.0205        | 0.9196 | 22000 | 0.4590          |
- | 2.0063        | 0.9614 | 23000 | 0.4587          |
-
-
- ### Framework versions
-
- - Transformers 5.0.0.dev0
- - Pytorch 2.8.0+cu128
- - Datasets 3.6.0
- - Tokenizers 0.22.2
  ---
+ license: mit
+ language:
+ - en
+ datasets:
+ - speechbrain/LoquaciousSet
+ base_model:
+ - zai-org/GLM-ASR-Nano-2512
+ - Qwen/Qwen3-0.6B
+ pipeline_tag: automatic-speech-recognition
  tags:
+ - asr
+ - speech-recognition
+ - audio
+ - qwen
+ - glm-asr
+ library_name: transformers
  ---
  
+ # Tiny Audio
+ 
+ A speech recognition model trained in 24 hours on a single GPU for ~$12. Built with [Tiny Audio](https://github.com/alexkroman/tiny-audio), a minimal, hackable ASR framework.
+ 
+ ## Quick Start
+ 
+ ```python
+ from transformers import pipeline
+ 
+ pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+ result = pipe("audio.wav")
+ print(result["text"])
+ ```
+ 
+ ## Usage Examples
+ 
+ ### Basic Transcription
+ 
+ ```python
+ from transformers import pipeline
+ 
+ pipe = pipeline("automatic-speech-recognition", model="mazesmazes/tiny-audio", trust_remote_code=True)
+ 
+ # From file
+ result = pipe("audio.wav")
+ print(result["text"])
+ 
+ # From URL
+ result = pipe("https://example.com/audio.mp3")
+ 
+ # From numpy array (must be 16kHz)
+ import numpy as np
+ audio = np.random.randn(16000).astype(np.float32)  # 1 second
+ result = pipe(audio)
+ ```
+ 
+ ### Batch Processing
+ 
+ ```python
+ # Process multiple files
+ files = ["audio1.wav", "audio2.wav", "audio3.wav"]
+ results = pipe(files, batch_size=4)
+ for r in results:
+     print(r["text"])
+ ```
+ 
+ ### Word-Level Timestamps
+ 
+ ```python
+ result = pipe("audio.wav", return_timestamps="word")
+ # Returns:
+ # {
+ #     "text": "hello world",
+ #     "chunks": [
+ #         {"text": "hello", "timestamp": (0.0, 0.5)},
+ #         {"text": "world", "timestamp": (0.6, 1.0)}
+ #     ]
+ # }
+ ```
+ 
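The chunk format above is all you need to build subtitles. A minimal sketch (not part of the model's API) that turns those chunks into SRT text, reusing the `pipe` object from the earlier examples:

```python
# Convert word-level chunks into SRT subtitle text.
# Assumes the chunk format shown above: {"text": ..., "timestamp": (start, end)} in seconds.
def to_srt(chunks):
    def fmt(t):
        h, rem = divmod(t, 3600)
        m, s = divmod(rem, 60)
        return f"{int(h):02}:{int(m):02}:{int(s):02},{int((s % 1) * 1000):03}"

    entries = []
    for i, chunk in enumerate(chunks, start=1):
        start, end = chunk["timestamp"]
        entries.append(f"{i}\n{fmt(start)} --> {fmt(end)}\n{chunk['text']}\n")
    return "\n".join(entries)

result = pipe("audio.wav", return_timestamps="word")
print(to_srt(result["chunks"]))
```
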
+ ### Streaming Inference
+ 
+ ```python
+ from tiny_audio import ASRModel, ASRProcessor
+ import torch
+ 
+ model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
+ processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
+ 
+ # Load and process audio
+ import librosa
+ audio, sr = librosa.load("audio.wav", sr=16000)
+ inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+ 
+ # Stream tokens
+ for token in model.generate_streaming(inputs["input_features"]):
+     print(token, end="", flush=True)
+ ```
+ 
+ ### Using with torch directly
+ 
+ ```python
+ from tiny_audio import ASRModel, ASRProcessor
+ import torch
+ import librosa
+ 
+ # Load model and processor
+ model = ASRModel.from_pretrained("mazesmazes/tiny-audio")
+ processor = ASRProcessor.from_pretrained("mazesmazes/tiny-audio")
+ 
+ # Load audio (16kHz)
+ audio, sr = librosa.load("audio.wav", sr=16000)
+ 
+ # Process
+ inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+ 
+ # Generate
+ with torch.no_grad():
+     output = model.generate(
+         input_features=inputs["input_features"],
+         attention_mask=inputs["attention_mask"],
+         max_new_tokens=256
+     )
+ 
+ # Decode
+ text = processor.batch_decode(output, skip_special_tokens=True)[0]
+ print(text)
+ ```
+ 
+ ### GPU Inference
+ 
+ ```python
+ import torch
+ from transformers import pipeline
+ 
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model="mazesmazes/tiny-audio",
+     trust_remote_code=True,
+     device="cuda"  # or device=0
+ )
+ ```
+ 
+ ### Half Precision
+ 
+ ```python
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model="mazesmazes/tiny-audio",
+     trust_remote_code=True,
+     torch_dtype=torch.float16,
+     device="cuda"
+ )
+ ```
+ 
+ ## Architecture
+ 
+ ```
+ Audio (16kHz) → GLM-ASR Encoder (frozen) → MLP Projector (trained) → Qwen3 (frozen) → Text
+ ```
+ 
+ Only the projector is trained (~12M params). The encoder and decoder remain frozen, leveraging their pretrained knowledge.
+ 
+ | Component | Model | Parameters | Status |
+ |-----------|-------|------------|--------|
+ | Audio Encoder | GLM-ASR-Nano-2512 | ~600M | Frozen |
+ | Projector | 2-layer MLP | ~12M | Trained |
+ | Language Model | Qwen3-0.6B | ~600M | Frozen |
+ 
+ ### How It Works
+ 
+ 1. **Audio Encoder**: GLM-ASR converts 16kHz audio into frame-level embeddings (768-dim)
+ 2. **Projector**: A 2-layer MLP with frame stacking bridges the audio and text embedding spaces
+ 3. **Language Model**: Qwen3 generates text autoregressively, conditioned on the projected audio
+ 
+ The projector reduces sequence length via frame stacking: `output_len = (input_len - 5) // 5 + 1`
+ 
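That formula corresponds to stacking 5 consecutive encoder frames with a stride of 5. A rough, illustrative sketch of such a projector follows; the hidden sizes and layer layout here are assumptions for illustration, not the repo's actual implementation:

```python
import torch
import torch.nn as nn

class StackingProjector(nn.Module):
    """Illustrative frame-stacking projector: concatenate 5 encoder frames,
    then map the stacked vector into the LM embedding space with a 2-layer MLP."""

    def __init__(self, enc_dim=768, lm_dim=1024, stack=5):
        super().__init__()
        self.stack = stack
        self.mlp = nn.Sequential(
            nn.Linear(enc_dim * stack, lm_dim),
            nn.GELU(),
            nn.Linear(lm_dim, lm_dim),
        )

    def forward(self, frames):  # frames: (batch, input_len, enc_dim)
        # unfold over time -> (batch, output_len, enc_dim, stack),
        # where output_len = (input_len - stack) // stack + 1
        stacked = frames.unfold(1, self.stack, self.stack).flatten(2)
        return self.mlp(stacked)

x = torch.randn(1, 100, 768)         # 100 encoder frames
print(StackingProjector()(x).shape)  # torch.Size([1, 20, 1024]); (100 - 5) // 5 + 1 = 20
```
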
+ ## Model Specifications
+ 
+ | Specification | Value |
+ |---------------|-------|
+ | Input | Audio (16kHz mono) |
+ | Output | Text transcription |
+ | Max Audio Length | ~30 seconds (limited by encoder) |
+ | Vocabulary | Qwen3 tokenizer |
+ | Languages | English only |
+ | Generation | Greedy decoding (num_beams=1, do_sample=False) |
+ 
+ ## Training Details
+ 
+ | Setting | Value |
+ |---------|-------|
+ | **Dataset** | LoquaciousSet (25,000 hours) |
+ | **Hardware** | Single NVIDIA A40 |
+ | **Time** | ~24 hours |
+ | **Cost** | ~$12 |
+ | **Optimizer** | AdamW |
+ | **Learning Rate** | 1e-4 |
+ | **Batch Size** | 4 |
+ | **Steps** | 50,000 |
+ 
+ ## Limitations
+ 
+ - **English only**: Not trained on other languages
+ - **Sample rate**: Expects 16kHz audio (other rates are resampled automatically; see the resampling sketch below)
+ - **Audio length**: Best for clips under 30 seconds
+ - **Accuracy**: May degrade on:
+   - Heavily accented speech
+   - Noisy or low-quality audio
+   - Domain-specific terminology
+   - Overlapping speakers
+ - **No punctuation**: Output is lowercase without punctuation by default
+ 
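For the sample-rate point above: if your audio is not 16kHz, you can also resample it yourself before calling the pipeline. A minimal sketch with torchaudio (the filename is a placeholder):

```python
import torchaudio

waveform, sr = torchaudio.load("audio_44k.wav")  # (channels, samples), any sample rate
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)
audio = waveform.mean(dim=0).numpy()  # mono float32 at 16kHz
result = pipe(audio)
```
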
+ ## Requirements
+ 
+ ```
+ transformers>=4.40.0
+ torch>=2.0.0
+ torchaudio>=2.0.0
+ ```
+ 
+ Optional for streaming:
+ ```
+ librosa
+ soundfile
+ ```
+ 
+ ## Files
+ 
+ | File | Description |
+ |------|-------------|
+ | `config.json` | Model configuration |
+ | `model.safetensors` | Projector weights (~48MB) |
+ | `preprocessor_config.json` | Audio preprocessing config |
+ | `tokenizer.json` | Tokenizer |
+ | `tokenizer_config.json` | Tokenizer config |
+ | `special_tokens_map.json` | Special tokens |
+ 
+ Note: Only the projector weights are stored. The encoder (GLM-ASR) and decoder (Qwen3) are loaded from their respective HuggingFace repos.
+ 
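A quick, optional way to confirm what is actually stored in this repo (and that the full encoder/decoder weights are not) is to list its files with `huggingface_hub`:

```python
from huggingface_hub import list_repo_files

# The base encoder (GLM-ASR) and decoder (Qwen3) are pulled from their own repos
# at load time; this repo should only contain the projector and config files.
for name in list_repo_files("mazesmazes/tiny-audio"):
    print(name)
```
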
+ ## Citation
+ 
+ If you use this model, please cite:
+ 
+ ```bibtex
+ @misc{tinyaudio2024,
+   author = {Alex Kroman},
+   title = {Tiny Audio: Minimal ASR Training},
+   year = {2024},
+   publisher = {GitHub},
+   url = {https://github.com/alexkroman/tiny-audio}
+ }
+ ```
+ 
+ ## Links
+ 
+ - [GitHub Repository](https://github.com/alexkroman/tiny-audio) - Train your own model
+ - [Free 3.5-hour Course](https://github.com/alexkroman/tiny-audio/blob/main/docs/course/0-course-overview.md) - Learn ASR from scratch
+ - [Live Demo](https://huggingface.co/spaces/mazesmazes/tiny-audio) - Try it in your browser
+ 
+ ## Acknowledgments
+ 
+ - [GLM-ASR](https://huggingface.co/zai-org/GLM-ASR-Nano-2512) for the audio encoder
+ - [Qwen3](https://huggingface.co/Qwen/Qwen3-0.6B) for the language model
+ - [LoquaciousSet](https://huggingface.co/datasets/speechbrain/LoquaciousSet) for training data
+ 
+ ## License
+ 
+ MIT
asr_pipeline.py CHANGED
@@ -1,6 +1,7 @@
  """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
  
  import re
  from pathlib import Path
  from typing import Any
  
@@ -23,8 +24,135 @@ def _get_device() -> str:
      return "cpu"
  
  
  class ForcedAligner:
-     """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""
  
      _bundle = None
      _model = None
@@ -44,7 +172,8 @@ class ForcedAligner:
          if cls._model is None:
              import torchaudio
  
-             cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
              cls._model = cls._bundle.get_model().to(device)
              cls._model.eval()
              cls._labels = cls._bundle.get_labels()
@@ -57,28 +186,29 @@ class ForcedAligner:
          audio: np.ndarray,
          text: str,
          sample_rate: int = 16000,
-         _language: str = "eng",
          _batch_size: int = 16,
      ) -> list[dict]:
          """Align transcript to audio and return word-level timestamps.
  
          Args:
              audio: Audio waveform as numpy array
              text: Transcript text to align
              sample_rate: Audio sample rate (default 16000)
-             _language: ISO-639-3 language code (default "eng" for English, unused)
-             _batch_size: Batch size for alignment model (unused)
  
          Returns:
              List of dicts with 'word', 'start', 'end' keys
          """
          import torchaudio
-         from torchaudio.functional import forced_align, merge_tokens
  
          device = _get_device()
          model, labels, dictionary = cls.get_instance(device)
  
-         # Convert audio to tensor (copy to ensure array is writable)
          if isinstance(audio, np.ndarray):
              waveform = torch.from_numpy(audio.copy()).float()
          else:
@@ -88,7 +218,7 @@ class ForcedAligner:
          if waveform.dim() == 1:
              waveform = waveform.unsqueeze(0)
  
-         # Resample if needed (wav2vec2 expects 16kHz)
          if sample_rate != cls._bundle.sample_rate:
              waveform = torchaudio.functional.resample(
                  waveform, sample_rate, cls._bundle.sample_rate
@@ -103,67 +233,47 @@ class ForcedAligner:
  
          emission = emissions[0].cpu()
  
-         # Normalize text: uppercase, keep only valid characters
          transcript = text.upper()
-         # Build tokens from transcript
          tokens = []
          for char in transcript:
              if char in dictionary:
                  tokens.append(dictionary[char])
              elif char == " ":
-                 tokens.append(dictionary.get("|", dictionary.get(" ", 0)))
  
          if not tokens:
              return []
  
-         targets = torch.tensor([tokens], dtype=torch.int32)
-
-         # Run forced alignment
-         # Note: forced_align is deprecated in torchaudio 2.6+ and will be removed in 2.9 (late 2025)
-         # No official replacement announced yet. See https://github.com/pytorch/audio/issues/3902
-         aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
  
-         # Use torchaudio's merge_tokens to get token spans (removes blanks and merges repeats)
-         token_spans = merge_tokens(aligned_tokens[0], scores[0])
  
-         # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
-         frame_duration = 320 / cls._bundle.sample_rate
  
-         # Group token spans into words based on pipe separator
          words = text.split()
          word_timestamps = []
-         current_word_start = None
-         current_word_end = None
-         word_idx = 0
-
-         for span in token_spans:
-             token_char = labels[span.token]
-             if token_char == "|":  # Word separator
-                 if current_word_start is not None and word_idx < len(words):
-                     word_timestamps.append(
-                         {
-                             "word": words[word_idx],
-                             "start": current_word_start * frame_duration,
-                             "end": current_word_end * frame_duration,
-                         }
-                     )
-                     word_idx += 1
-                 current_word_start = None
-                 current_word_end = None
-             else:
-                 if current_word_start is None:
-                     current_word_start = span.start
-                 current_word_end = span.end
-
-         # Don't forget the last word
-         if current_word_start is not None and word_idx < len(words):
-             word_timestamps.append(
-                 {
-                     "word": words[word_idx],
-                     "start": current_word_start * frame_duration,
-                     "end": current_word_end * frame_duration,
-                 }
-             )
  
          return word_timestamps
  
  """ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""
  
  import re
+ from dataclasses import dataclass
  from pathlib import Path
  from typing import Any
  
      return "cpu"
  
  
+ @dataclass
+ class _AlignPoint:
+     """A point in the alignment path."""
+ 
+     token_index: int
+     time_index: int
+     score: float
+ 
+ 
+ @dataclass
+ class _AlignSegment:
+     """An aligned character/word segment."""
+ 
+     label: str
+     start: int
+     end: int
+     score: float
+ 
+     @property
+     def length(self):
+         return self.end - self.start
+ 
+ 
+ def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
+     """Build dynamic programming trellis for CTC alignment.
+ 
+     Based on WhisperX's alignment algorithm for improved accuracy.
+     """
+     num_frame = emission.size(0)
+     num_tokens = len(tokens)
+ 
+     trellis = torch.zeros((num_frame, num_tokens))
+     trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
+     trellis[0, 1:] = -float("inf")
+     trellis[-num_tokens + 1 :, 0] = float("inf")
+ 
+     for t in range(num_frame - 1):
+         trellis[t + 1, 1:] = torch.maximum(
+             # Score for staying at the same token
+             trellis[t, 1:] + emission[t, blank_id],
+             # Score for changing to the next token
+             trellis[t, :-1] + emission[t, tokens[1:]],
+         )
+     return trellis
+ 
+ 
+ def _backtrack(
+     trellis: torch.Tensor,
+     emission: torch.Tensor,
+     tokens: list[int],
+     blank_id: int = 0,
+ ) -> list[_AlignPoint]:
+     """Backtrack through trellis to find optimal alignment path."""
+     t, j = trellis.size(0) - 1, trellis.size(1) - 1
+ 
+     path = [_AlignPoint(j, t, emission[t, blank_id].exp().item())]
+     while j > 0:
+         assert t > 0
+ 
+         p_stay = emission[t - 1, blank_id]
+         p_change = emission[t - 1, tokens[j]]
+ 
+         stayed = trellis[t - 1, j] + p_stay
+         changed = trellis[t - 1, j - 1] + p_change
+ 
+         t -= 1
+         if changed > stayed:
+             j -= 1
+ 
+         prob = (p_change if changed > stayed else p_stay).exp().item()
+         path.append(_AlignPoint(j, t, prob))
+ 
+     while t > 0:
+         prob = emission[t - 1, blank_id].exp().item()
+         path.append(_AlignPoint(j, t - 1, prob))
+         t -= 1
+ 
+     return path[::-1]
+ 
+ 
+ def _merge_repeats(path: list[_AlignPoint], transcript: str) -> list[_AlignSegment]:
+     """Merge repeated tokens into character segments."""
+     i1, i2 = 0, 0
+     segments = []
+     while i1 < len(path):
+         while i2 < len(path) and path[i1].token_index == path[i2].token_index:
+             i2 += 1
+         score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
+         segments.append(
+             _AlignSegment(
+                 transcript[path[i1].token_index],
+                 path[i1].time_index,
+                 path[i2 - 1].time_index + 1,
+                 score,
+             )
+         )
+         i1 = i2
+     return segments
+ 
+ 
+ def _merge_words(segments: list[_AlignSegment], separator: str = "|") -> list[_AlignSegment]:
+     """Merge character segments into word segments."""
+     words = []
+     i1, i2 = 0, 0
+     while i1 < len(segments):
+         if i2 >= len(segments) or segments[i2].label == separator:
+             if i1 != i2:
+                 segs = segments[i1:i2]
+                 word = "".join([seg.label for seg in segs])
+                 total_length = sum(seg.length for seg in segs)
+                 score = (
+                     sum(seg.score * seg.length for seg in segs) / total_length
+                     if total_length > 0
+                     else 0
+                 )
+                 words.append(_AlignSegment(word, segments[i1].start, segments[i2 - 1].end, score))
+             i1 = i2 + 1
+             i2 = i1
+         else:
+             i2 += 1
+     return words
+ 
+ 
  class ForcedAligner:
+     """Forced aligner for word-level timestamps using wav2vec2.
+ 
+     Uses WhisperX-style dynamic programming alignment for improved accuracy
+     over simple CTC greedy alignment.
+     """
  
      _bundle = None
      _model = None
  
          if cls._model is None:
              import torchaudio
  
+             # Use LARGE model for better accuracy (same as WhisperX recommendation)
+             cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_960H
              cls._model = cls._bundle.get_model().to(device)
              cls._model.eval()
              cls._labels = cls._bundle.get_labels()
  
          audio: np.ndarray,
          text: str,
          sample_rate: int = 16000,
+         _language: str = "en",
          _batch_size: int = 16,
      ) -> list[dict]:
          """Align transcript to audio and return word-level timestamps.
  
+         Uses WhisperX-style dynamic programming for improved alignment accuracy.
+ 
          Args:
              audio: Audio waveform as numpy array
              text: Transcript text to align
              sample_rate: Audio sample rate (default 16000)
+             _language: Language code (unused, English only)
+             _batch_size: Batch size (unused)
  
          Returns:
              List of dicts with 'word', 'start', 'end' keys
          """
          import torchaudio
  
          device = _get_device()
          model, labels, dictionary = cls.get_instance(device)
  
+         # Convert audio to tensor
          if isinstance(audio, np.ndarray):
              waveform = torch.from_numpy(audio.copy()).float()
          else:
  
          if waveform.dim() == 1:
              waveform = waveform.unsqueeze(0)
  
+         # Resample if needed
          if sample_rate != cls._bundle.sample_rate:
              waveform = torchaudio.functional.resample(
                  waveform, sample_rate, cls._bundle.sample_rate
  
  
          emission = emissions[0].cpu()
  
+         # Normalize text and build token sequence
          transcript = text.upper()
          tokens = []
+         clean_transcript = ""
+ 
          for char in transcript:
              if char in dictionary:
                  tokens.append(dictionary[char])
+                 clean_transcript += char
              elif char == " ":
+                 sep_token = dictionary.get("|", dictionary.get(" ", 0))
+                 tokens.append(sep_token)
+                 clean_transcript += "|"
  
          if not tokens:
              return []
  
+         # Build trellis and find optimal path (WhisperX-style DP alignment)
+         trellis = _get_trellis(emission, tokens, blank_id=0)
+         path = _backtrack(trellis, emission, tokens, blank_id=0)
  
+         # Merge into character segments, then word segments
+         char_segments = _merge_repeats(path, clean_transcript)
+         word_segments = _merge_words(char_segments, separator="|")
  
+         # Convert frame indices to time
+         frame_duration = 320 / cls._bundle.sample_rate  # 20ms per frame
  
+         # Build output with original words
          words = text.split()
          word_timestamps = []
+ 
+         for i, seg in enumerate(word_segments):
+             if i < len(words):
+                 word_timestamps.append(
+                     {
+                         "word": words[i],
+                         "start": seg.start * frame_duration,
+                         "end": seg.end * frame_duration,
+                     }
+                 )
  
          return word_timestamps
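For reference, a sketch of how the aligner in this hunk might be exercised on its own. It assumes `align` is exposed as a classmethod, as the `cls` usage above suggests; in normal use the pipeline calls it internally when `return_timestamps="word"` is requested:

```python
import numpy as np

# Hypothetical direct call to the aligner defined above (normally invoked by the pipeline).
audio = np.random.randn(16000 * 2).astype(np.float32)  # 2 seconds of 16kHz audio
words = ForcedAligner.align(audio, "hello world", sample_rate=16000)
# -> [{"word": "hello", "start": ..., "end": ...}, {"word": "world", ...}]
```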