修复了 torchaudio 兼容性问题:改用 librosa/soundfile 实现重采样、fbank 特征提取与音频保存,以兼容 ARM 架构

This commit is contained in:
cloyir 2026-02-15 03:35:39 +08:00
parent d8e5f90086
commit 219b5fc93c

View File

@ -6,8 +6,9 @@ import json
import re
import time
import librosa
import soundfile as sf
import numpy as np
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import warnings
@ -353,6 +354,41 @@ class IndexTTS2:
return emo_vector
def _compute_fbank_librosa(self, audio, n_mels=80, sr=16000):
    """Compute log-mel (fbank) features with librosa instead of
    torchaudio.compliance.kaldi.fbank, for ARM-architecture compatibility.

    Args:
        audio: torch tensor holding a waveform (singleton dims are squeezed).
        n_mels: number of mel filterbank channels.
        sr: sample rate of ``audio`` in Hz.

    Returns:
        Tensor of shape (time, n_mels) on the same device as ``audio``,
        matching the layout of Kaldi fbank output.
    """
    original_device = audio.device
    # librosa operates on numpy arrays; drop singleton dims and move to CPU.
    waveform = audio.squeeze().cpu().numpy()

    # n_fft=512, win_length=400 (25 ms) and hop_length=160 (10 ms) mirror
    # the Kaldi defaults at 16 kHz; center=False matches snip_edges=True.
    # NOTE(review): librosa uses a Hann window with no pre-emphasis/dither,
    # so this only approximates Kaldi fbank — confirm downstream tolerance.
    power_mel = librosa.feature.melspectrogram(
        y=waveform,
        sr=sr,
        n_fft=512,
        hop_length=160,
        win_length=400,
        n_mels=n_mels,
        fmin=20,
        fmax=7600,
        center=False,
        power=2.0,  # power spectrum (Kaldi fbank's default use_power=True)
    )

    # Back to torch on the caller's device, then log-compress with a small
    # floor (Kaldi fbank applies log energies by default).
    features = torch.from_numpy(power_mel).to(original_device)
    features = torch.log(features + 1e-10)
    # (n_mels, time) -> (time, n_mels), the layout Kaldi fbank produces.
    return features.transpose(0, 1)
# 原始推理模式
def infer(self, spk_audio_prompt, text, output_path,
emo_audio_prompt=None, emo_alpha=1.0,
@ -433,8 +469,12 @@ class IndexTTS2:
self.cache_mel = None
torch.cuda.empty_cache()
audio,sr = self._load_and_cut_audio(spk_audio_prompt,15,verbose)
audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
# 使用 librosa 进行重采样
audio_np = audio.squeeze().numpy()
audio_22k_np = librosa.resample(audio_np, orig_sr=sr, target_sr=22050)
audio_16k_np = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
audio_22k = torch.from_numpy(audio_22k_np).unsqueeze(0)
audio_16k = torch.from_numpy(audio_16k_np).unsqueeze(0)
inputs = self.extract_features(audio_16k, sampling_rate=16000, return_tensors="pt")
input_features = inputs["input_features"]
@ -446,10 +486,8 @@ class IndexTTS2:
_, S_ref = self.semantic_codec.quantize(spk_cond_emb)
ref_mel = self.mel_fn(audio_22k.to(spk_cond_emb.device).float())
ref_target_lengths = torch.LongTensor([ref_mel.size(2)]).to(ref_mel.device)
feat = torchaudio.compliance.kaldi.fbank(audio_16k.to(ref_mel.device),
num_mel_bins=80,
dither=0,
sample_frequency=16000)
# 使用 librosa 计算 fbank 特征(替代 torchaudio.compliance.kaldi.fbank
feat = self._compute_fbank_librosa(audio_16k.to(ref_mel.device), n_mels=80, sr=16000)
feat = feat - feat.mean(dim=0, keepdim=True) # feat2另外一个滤波器能量组特征[922, 80]
style = self.campplus_model(feat.unsqueeze(0)) # 参考音频的全局style2[1,192]
@ -694,7 +732,9 @@ class IndexTTS2:
print(">> remove old wav file:", output_path)
if os.path.dirname(output_path) != "":
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
# 使用 soundfile 保存音频(替代 torchaudio.save
wav_np = wav.squeeze().numpy().astype(np.int16)
sf.write(output_path, wav_np, sampling_rate, subtype='PCM_16')
print(">> wav file saved to:", output_path)
if stream_return:
return None