txya900619 committed on
Commit 297b43c · 1 Parent(s): 4c9914b

feat: refactor app.py and use whisperx

Files changed (3):
  1. app.py +28 -23
  2. requirements.txt +1 -1
  3. whisper.py +539 -0
app.py CHANGED
@@ -1,7 +1,8 @@
 import tempfile
 
 import gradio as gr
-from faster_whisper import BatchedInferencePipeline, WhisperModel
+
+from whisper import load_audio, load_model
 
 try:
     import spaces
@@ -10,6 +11,8 @@ try:
 except ImportError:
     USING_SPACES = False
 
+SAMPLING_RATE = 16000
+
 
 def gpu_decorator(func):
     if USING_SPACES:
@@ -18,10 +21,11 @@ def gpu_decorator(func):
     return func
 
 
-model = WhisperModel(
+model = load_model(
     "formospeech/whisper-large-v2-formosan-all-ct2",
+    device="cuda",
+    asr_options={"word_timestamps": True},
 )
-model = BatchedInferencePipeline(model=model)
 
 with gr.Blocks() as demo:
     gr.Markdown(
@@ -52,9 +56,9 @@ with gr.Blocks() as demo:
     )
     min_silence_duration_ms_slider = gr.Slider(
         label="min silence duration ms",
-        minimum=50,
-        maximum=1000,
-        step=50,
+        minimum=10,
+        maximum=500,
+        step=10,
         value=150,
         info="Minimum duration of silence (in ms) to consider a segment as speech.",
     )
@@ -67,30 +71,31 @@ with gr.Blocks() as demo:
     )
 
     @gpu_decorator
-    def generate_srt(audio, threshold, neg_threshold, min_silence_duration_ms=500):
-        segments, info = model.transcribe(
+    def generate_srt(audio, threshold, min_silence_duration_ms=500):
+        audio = load_audio(audio, sr=SAMPLING_RATE)
+
+        output = model.transcribe(
             audio,
             language="id",
-            beam_size=5,
-            vad_filter=True,
-            vad_parameters={
-                "threshold": threshold,
-                "min_silence_duration_ms": min_silence_duration_ms,
-            },
             batch_size=32,
         )
+
+        segments = output["segments"]
+        print(segments)
+
         srt_content = ""
-        for segment in segments:
-            srt_content += f"{segment.id}\n"
-
-            # convert float seconds to SRT time format
-            start_time = segment.start
-            end_time = segment.end
-            start_time_srt = f"{int(start_time // 3600):02}:{int((start_time % 3600) // 60):02}:{int(start_time % 60):02},{int((start_time % 1) * 1000):03}"
-            end_time_srt = f"{int(end_time // 3600):02}:{int((end_time % 3600) // 60):02}:{int(end_time % 60):02},{int((end_time % 1) * 1000):03}"
+
+        for i, segment in enumerate(segments):
+            start_seconds = segment["start"]
+            end_seconds = segment["end"]
+
+            srt_content += f"{i + 1}\n"
+
+            start_time_srt = f"{int(start_seconds // 3600):02}:{int((start_seconds % 3600) // 60):02}:{int(start_seconds % 60):02},{int((start_seconds % 1) * 1000):03}"
+            end_time_srt = f"{int(end_seconds // 3600):02}:{int((end_seconds % 3600) // 60):02}:{int(end_seconds % 60):02},{int((end_seconds % 1) * 1000):03}"
             srt_content += f"{start_time_srt} --> {end_time_srt}\n"
 
-            srt_content += f"族語:{segment.text.strip()}\n"
+            srt_content += f"族語:{segment['text']}\n"
             srt_content += "華語:\n\n"
 
         return srt_content.strip()
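
Note: the timestamp arithmetic inlined in generate_srt can be read as a small standalone helper. The sketch below is illustrative only and not part of this commit; it simply mirrors the f-strings used above, and seconds_to_srt_time is a hypothetical name.

def seconds_to_srt_time(seconds: float) -> str:
    # Format float seconds as an SRT timecode (HH:MM:SS,mmm), mirroring the
    # f-strings in generate_srt above. Illustrative helper, not in the commit.
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int((seconds % 1) * 1000)
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

print(seconds_to_srt_time(3723.5))  # 01:02:03,500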
requirements.txt CHANGED
@@ -1 +1 @@
-faster_whisper
+whisperx
whisper.py ADDED
@@ -0,0 +1,539 @@
+import os
+from dataclasses import replace
+from math import ceil
+from typing import List, Optional, Union
+
+import ctranslate2
+import faster_whisper
+import numpy as np
+import torch
+from faster_whisper.tokenizer import Tokenizer
+from faster_whisper.transcribe import TranscriptionOptions, get_ctranslate2_storage
+from transformers import Pipeline
+from transformers.pipelines.pt_utils import PipelineIterator
+from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
+from whisperx.types import SingleSegment, TranscriptionResult
+from whisperx.vads import Pyannote, Silero, Vad
+from whisperx.vads.pyannote import Binarize
+
+
+def find_numeral_symbol_tokens(tokenizer):
+    numeral_symbol_tokens = []
+    for i in range(tokenizer.eot):
+        token = tokenizer.decode([i]).removeprefix(" ")
+        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
+        if has_numeral_symbol:
+            numeral_symbol_tokens.append(i)
+    return numeral_symbol_tokens
+
+
+class WhisperModel(faster_whisper.WhisperModel):
+    """
+    FasterWhisperModel provides batched inference for faster-whisper.
+    Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
+    """
+
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: Tokenizer,
+        options: TranscriptionOptions,
+        encoder_output=None,
+    ):
+        batch_size = features.shape[0]
+        all_tokens = []
+        prompt_reset_since = 0
+        if options.initial_prompt is not None:
+            initial_prompt = " " + options.initial_prompt.strip()
+            initial_prompt_tokens = tokenizer.encode(initial_prompt)
+            all_tokens.extend(initial_prompt_tokens)
+        previous_tokens = all_tokens[prompt_reset_since:]
+        prompt = self.get_prompt(
+            tokenizer,
+            previous_tokens,
+            without_timestamps=options.without_timestamps,
+            prefix=options.prefix,
+            hotwords=options.hotwords,
+        )
+
+        encoder_output = self.encode(features)
+
+        max_initial_timestamp_index = int(
+            round(options.max_initial_timestamp / self.time_precision)
+        )
+
+        result = self.model.generate(
+            encoder_output,
+            [prompt] * batch_size,
+            beam_size=options.beam_size,
+            patience=options.patience,
+            length_penalty=options.length_penalty,
+            max_length=self.max_length,
+            suppress_blank=options.suppress_blank,
+            suppress_tokens=options.suppress_tokens,
+        )
+
+        tokens_batch = [x.sequences_ids[0] for x in result]
+
+        def decode_batch(tokens: List[List[int]]) -> str:
+            res = []
+            for tk in tokens:
+                res.append([token for token in tk if token < tokenizer.eot])
+            # text_tokens = [token for token in tokens if token < self.eot]
+            return tokenizer.tokenizer.decode_batch(res)
+
+        text = decode_batch(tokens_batch)
+
+        return encoder_output, text, tokens_batch
+
+    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+        # When the model is running on multiple GPUs, the encoder output should be moved
+        # to the CPU since we don't know which GPU will handle the next job.
+        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
+        # unsqueeze if batch size = 1
+        if len(features.shape) == 2:
+            features = np.expand_dims(features, 0)
+        features = get_ctranslate2_storage(features)
+
+        return self.model.encode(features, to_cpu=to_cpu)
+
+
+class FasterWhisperPipeline(Pipeline):
+    """
+    Huggingface Pipeline wrapper for FasterWhisperModel.
+    """
+
+    # TODO:
+    # - add support for timestamp mode
+    # - add support for custom inference kwargs
+
+    def __init__(
+        self,
+        model: WhisperModel,
+        vad,
+        vad_params: dict,
+        options: TranscriptionOptions,
+        tokenizer: Optional[Tokenizer] = None,
+        device: Union[int, str, "torch.device"] = -1,
+        framework="pt",
+        language: Optional[str] = None,
+        suppress_numerals: bool = False,
+        **kwargs,
+    ):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.options = options
+        self.preset_language = language
+        self.suppress_numerals = suppress_numerals
+        self._batch_size = kwargs.pop("batch_size", None)
+        self._num_workers = 1
+        self._preprocess_params, self._forward_params, self._postprocess_params = (
+            self._sanitize_parameters(**kwargs)
+        )
+        self.call_count = 0
+        self.framework = framework
+        if self.framework == "pt":
+            if isinstance(device, torch.device):
+                self.device = device
+            elif isinstance(device, str):
+                self.device = torch.device(device)
+            elif device < 0:
+                self.device = torch.device("cpu")
+            else:
+                self.device = torch.device(f"cuda:{device}")
+        else:
+            self.device = device
+
+        super(Pipeline, self).__init__()
+        self.vad_model = vad
+        self._vad_params = vad_params
+        self.last_speech_timestamp = 0.0
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "tokenizer" in kwargs:
+            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, input_dict):
+        audio = input_dict["inputs"]
+
+        model_n_mels = self.model.feat_kwargs.get("feature_size")
+        features = log_mel_spectrogram(
+            audio,
+            n_mels=model_n_mels if model_n_mels is not None else 80,
+            padding=N_SAMPLES - audio.shape[0],
+        )
+        return {
+            "inputs": features,
+            "start": input_dict["start"],
+            "end": input_dict["end"],
+            "segment_size": input_dict["segment_size"],
+        }
+
+    def _forward(self, model_inputs):
+        encoder_output, text, tokens = self.model.generate_segment_batched(
+            model_inputs["inputs"], self.tokenizer, self.options
+        )
+        outputs = [
+            [
+                {
+                    "tokens": tokens[i],
+                    "start": model_inputs["start"][i],
+                    "end": model_inputs["end"][i],
+                    "seek": int(model_inputs["start"][i] * 100),
+                }
+            ]
+            for i in range(len(tokens))
+        ]
+
+        self.last_speech_timestamp = self.model.add_word_timestamps(
+            outputs,
+            self.tokenizer,
+            encoder_output,
+            num_frames=model_inputs["segment_size"],
+            prepend_punctuations="\"'“¿([{-",
+            append_punctuations="\"'.。,,!!??::”)]}、",
+            last_speech_timestamp=self.last_speech_timestamp,
+        )
+
+        outputs = [outputs[i][0]["words"] for i in range(len(outputs))]
+        outputs = sum(outputs, [])
+        return {
+            "words": [outputs],
+        }
+
+    def postprocess(self, model_outputs):
+        return model_outputs
+
+    def get_iterator(
+        self,
+        inputs,
+        num_workers: int,
+        batch_size: int,
+        preprocess_params: dict,
+        forward_params: dict,
+        postprocess_params: dict,
+    ):
+        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
+        if "TOKENIZERS_PARALLELISM" not in os.environ:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        # TODO hack by collating feature_extractor and image_processor
+
+        def stack(items):
+            return {
+                "inputs": torch.stack([x["inputs"] for x in items]),
+                "start": [x["start"] for x in items],
+                "end": [x["end"] for x in items],
+                "segment_size": [x["segment_size"] for x in items],
+            }
+
+        dataloader = torch.utils.data.DataLoader(
+            dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack
+        )
+        model_iterator = PipelineIterator(
+            dataloader, self.forward, forward_params, loader_batch_size=batch_size
+        )
+        final_iterator = PipelineIterator(
+            model_iterator, self.postprocess, postprocess_params
+        )
+        return final_iterator
+
+    def transcribe(
+        self,
+        audio: Union[str, np.ndarray],
+        batch_size: Optional[int] = None,
+        num_workers=0,
+        language: Optional[str] = None,
+        task: Optional[str] = None,
+        chunk_size=30,
+        print_progress=False,
+        combined_progress=False,
+        verbose=False,
+    ) -> TranscriptionResult:
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+
+        def data(audio, segments):
+            for seg in segments:
+                f1 = int(seg["start"] * SAMPLE_RATE)
+                f2 = int(seg["end"] * SAMPLE_RATE)
+                # print(f2-f1)
+                yield {
+                    "inputs": audio[f1:f2],
+                    "start": seg["start"],
+                    "end": seg["end"],
+                    "segment_size": int(
+                        ceil(seg["end"] - seg["start"]) * self.model.frames_per_second
+                    ),
+                }
+
+        # Pre-process audio and merge chunks as defined by the respective VAD child class
+        # In case vad_model is manually assigned (see 'load_model') follow the functionality of pyannote toolkit
+        if issubclass(type(self.vad_model), Vad):
+            waveform = self.vad_model.preprocess_audio(audio)
+            merge_chunks = self.vad_model.merge_chunks
+        else:
+            waveform = Pyannote.preprocess_audio(audio)
+            merge_chunks = Pyannote.merge_chunks
+
+        pre_merge_vad_segments = self.vad_model(
+            {"waveform": waveform, "sample_rate": SAMPLE_RATE}
+        )
+        vad_segments = merge_chunks(
+            pre_merge_vad_segments,
+            chunk_size,
+            onset=self._vad_params["vad_onset"],
+            offset=self._vad_params["vad_offset"],
+        )
+        if self.tokenizer is None:
+            language = language or self.detect_language(audio)
+            task = task or "transcribe"
+            self.tokenizer = Tokenizer(
+                self.model.hf_tokenizer,
+                self.model.model.is_multilingual,
+                task=task,
+                language=language,
+            )
+        else:
+            language = language or self.tokenizer.language_code
+            task = task or self.tokenizer.task
+            if task != self.tokenizer.task or language != self.tokenizer.language_code:
+                self.tokenizer = Tokenizer(
+                    self.model.hf_tokenizer,
+                    self.model.model.is_multilingual,
+                    task=task,
+                    language=language,
+                )
+
+        if self.suppress_numerals:
+            previous_suppress_tokens = self.options.suppress_tokens
+            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
+            print("Suppressing numeral and symbol tokens")
+            new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
+            new_suppressed_tokens = list(set(new_suppressed_tokens))
+            self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
+
+        binarize = Binarize(
+            max_duration=chunk_size,
+            onset=self._vad_params["vad_onset"],
+            offset=self._vad_params["vad_offset"],
+        )
+        segments = binarize(pre_merge_vad_segments).get_timeline()
+        segments: List[SingleSegment] = [
+            {
+                "start": seg.start,
+                "end": seg.end,
+                "text": "",
+            }
+            for seg in segments
+        ]
+
+        batch_size = batch_size or self._batch_size
+        total_segments = len(vad_segments)
+        for idx, out in enumerate(
+            self.__call__(
+                data(audio, vad_segments),
+                batch_size=batch_size,
+                num_workers=num_workers,
+            )
+        ):
+            if print_progress:
+                base_progress = ((idx + 1) / total_segments) * 100
+                percent_complete = (
+                    base_progress / 2 if combined_progress else base_progress
+                )
+                print(f"Progress: {percent_complete:.2f}%...")
+
+            last_speech_timestamp_index = 0
+            next_last_speech_timestamp_index = 0
+            for word in out["words"]:
+                possiable_segment_indices = []
+
+                for i, segment in enumerate(segments[last_speech_timestamp_index:]):
+                    if segment["end"] < word["start"]:
+                        next_last_speech_timestamp_index = i + 1
+                    overlap_start = max(segment["start"], word["start"])
+                    overlap_end = min(segment["end"], word["end"])
+                    if overlap_start <= overlap_end:
+                        possiable_segment_indices.append(
+                            last_speech_timestamp_index + i
+                        )
+                last_speech_timestamp_index = next_last_speech_timestamp_index
+
+                if len(possiable_segment_indices) == 0:
+                    print(
+                        f"Warning: Word '{word['word']}' at [{round(word['start'], 3)} --> {round(word['end'], 3)}] is not in any segment."
+                    )
+                else:
+                    largest_overlap = -1
+                    best_segment_index = None
+                    for i in possiable_segment_indices:
+                        segment = segments[i]
+                        overlap_start = max(segment["start"], word["start"])
+                        overlap_end = min(segment["end"], word["end"])
+                        overlap_duration = overlap_end - overlap_start
+                        if overlap_duration > largest_overlap:
+                            largest_overlap = overlap_duration
+                            best_segment_index = i
+                    segments[best_segment_index]["text"] += word["word"]
+        # revert the tokenizer if multilingual inference is enabled
+        if self.preset_language is None:
+            self.tokenizer = None
+
+        # revert suppressed tokens if suppress_numerals is enabled
+        if self.suppress_numerals:
+            self.options = replace(
+                self.options, suppress_tokens=previous_suppress_tokens
+            )
+
+        return {"segments": segments, "language": language}
+
+    def detect_language(self, audio: np.ndarray) -> str:
+        if audio.shape[0] < N_SAMPLES:
+            print(
+                "Warning: audio is shorter than 30s, language detection may be inaccurate."
+            )
+        model_n_mels = self.model.feat_kwargs.get("feature_size")
+        segment = log_mel_spectrogram(
+            audio[:N_SAMPLES],
+            n_mels=model_n_mels if model_n_mels is not None else 80,
+            padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0],
+        )
+        encoder_output = self.model.encode(segment)
+        results = self.model.model.detect_language(encoder_output)
+        language_token, language_probability = results[0][0]
+        language = language_token[2:-2]
+        print(
+            f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
+        )
+        return language
+
+
+def load_model(
+    whisper_arch: str,
+    device: str,
+    device_index=0,
+    compute_type="float16",
+    asr_options: Optional[dict] = None,
+    language: Optional[str] = None,
+    vad_model: Optional[Vad] = None,
+    vad_method: Optional[str] = "pyannote",
+    vad_options: Optional[dict] = None,
+    model: Optional[WhisperModel] = None,
+    task="transcribe",
+    download_root: Optional[str] = None,
+    local_files_only=False,
+    threads=4,
+) -> FasterWhisperPipeline:
+    """Load a Whisper model for inference.
+    Args:
+        whisper_arch - The name of the Whisper model to load.
+        device - The device to load the model on.
+        compute_type - The compute type to use for the model.
+        vad_method - The vad method to use. vad_model has higher priority if is not None.
+        options - A dictionary of options to use for the model.
+        language - The language of the model. (use English for now)
+        model - The WhisperModel instance to use.
+        download_root - The root directory to download the model to.
+        local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
+        threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
+    Returns:
+        A Whisper pipeline.
+    """
+
+    if whisper_arch.endswith(".en"):
+        language = "en"
+
+    model = model or WhisperModel(
+        whisper_arch,
+        device=device,
+        device_index=device_index,
+        compute_type=compute_type,
+        download_root=download_root,
+        local_files_only=local_files_only,
+        cpu_threads=threads,
+    )
+    if language is not None:
+        tokenizer = Tokenizer(
+            model.hf_tokenizer,
+            model.model.is_multilingual,
+            task=task,
+            language=language,
+        )
+    else:
+        print(
+            "No language specified, language will be first be detected for each audio file (increases inference time)."
+        )
+        tokenizer = None
+
+    default_asr_options = {
+        "beam_size": 5,
+        "best_of": 5,
+        "patience": 1,
+        "length_penalty": 1,
+        "repetition_penalty": 1,
+        "no_repeat_ngram_size": 0,
+        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
+        "compression_ratio_threshold": 2.4,
+        "log_prob_threshold": -1.0,
+        "no_speech_threshold": 0.6,
+        "condition_on_previous_text": False,
+        "prompt_reset_on_temperature": 0.5,
+        "initial_prompt": None,
+        "prefix": None,
+        "suppress_blank": True,
+        "suppress_tokens": [-1],
+        "without_timestamps": True,
+        "max_initial_timestamp": 0.0,
+        "word_timestamps": False,
+        "prepend_punctuations": "\"'“¿([{-",
+        "append_punctuations": "\"'.。,,!!??::”)]}、",
+        "multilingual": model.model.is_multilingual,
+        "suppress_numerals": False,
+        "max_new_tokens": None,
+        "clip_timestamps": None,
+        "hallucination_silence_threshold": None,
+        "hotwords": None,
+    }
+
+    if asr_options is not None:
+        default_asr_options.update(asr_options)
+
+    suppress_numerals = default_asr_options["suppress_numerals"]
+    del default_asr_options["suppress_numerals"]
+
+    default_asr_options = TranscriptionOptions(**default_asr_options)
+
+    default_vad_options = {
+        "chunk_size": 30,  # needed by silero since binarization happens before merge_chunks
+        "vad_onset": 0.500,
+        "vad_offset": 0.363,
+    }
+
+    if vad_options is not None:
+        default_vad_options.update(vad_options)
+
+    # Note: manually assigned vad_model has higher priority than vad_method!
+    if vad_model is not None:
+        print("Use manually assigned vad_model. vad_method is ignored.")
+        vad_model = vad_model
+    else:
+        if vad_method == "silero":
+            vad_model = Silero(**default_vad_options)
+        elif vad_method == "pyannote":
+            vad_model = Pyannote(
+                torch.device(device), use_auth_token=None, **default_vad_options
+            )
+        else:
+            raise ValueError(f"Invalid vad_method: {vad_method}")
+
+    return FasterWhisperPipeline(
+        model=model,
+        vad=vad_model,
+        options=default_asr_options,
+        tokenizer=tokenizer,
+        language=language,
+        suppress_numerals=suppress_numerals,
+        vad_params=default_vad_options,
+    )