修复了 torchaudio 兼容性问题:改用 librosa 进行重采样与 fbank 特征计算,改用 soundfile 保存音频
This commit is contained in:
parent
d8e5f90086
commit
219b5fc93c
@ -6,8 +6,9 @@ import json
|
||||
import re
|
||||
import time
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
import warnings
|
||||
@ -353,6 +354,41 @@ class IndexTTS2:
|
||||
|
||||
return emo_vector
|
||||
|
||||
def _compute_fbank_librosa(self, audio, n_mels=80, sr=16000):
    """
    Compute log-mel fbank features with librosa, replacing
    torchaudio.compliance.kaldi.fbank for ARM-architecture compatibility.

    Args:
        audio: torch tensor holding the waveform, squeezable to 1-D
            (e.g. shape (1, samples)); may live on any device.
        n_mels: number of mel bins (default 80, matching the old call's
            num_mel_bins=80).
        sr: sample rate in Hz (default 16000).

    Returns:
        torch tensor of shape (time, n_mels) on the same device as
        ``audio``, matching the layout of Kaldi's fbank output.
    """
    # librosa operates on CPU numpy arrays; remember the source device
    # so the result can be moved back.
    device = audio.device
    audio_np = audio.squeeze().cpu().numpy()

    # Mel spectrogram with Kaldi-like parameters:
    # n_fft=512, hop_length=160, win_length=400 correspond to Kaldi's
    # default 25 ms window / 10 ms frame shift at 16 kHz.
    # htk=True selects the HTK mel formula (1127 * ln(1 + f/700)), which
    # is the scale Kaldi's filterbank uses; librosa's default (Slaney)
    # produces a different filterbank and would not match Kaldi.
    # NOTE(review): Kaldi additionally applies pre-emphasis (0.97) and a
    # 'povey' window by default; those are omitted here (the old fbank
    # call already disabled dither), so the output matches Kaldi only
    # approximately. The downstream per-feature mean subtraction absorbs
    # constant offsets — confirm this is close enough for campplus.
    mel_spec = librosa.feature.melspectrogram(
        y=audio_np,
        sr=sr,
        n_fft=512,
        hop_length=160,
        win_length=400,
        n_mels=n_mels,
        fmin=20,
        fmax=7600,
        htk=True,
        center=False,
        power=2.0,  # power spectrum, as Kaldi's fbank uses by default
    )

    # Back to a torch tensor on the original device.
    mel_spec = torch.from_numpy(mel_spec).to(device)

    # Natural-log compression with a small floor to avoid log(0)
    # (Kaldi fbank returns log mel energies).
    mel_spec = torch.log(mel_spec + 1e-10)

    # Transpose (n_mels, time) -> (time, n_mels) to match Kaldi's layout.
    return mel_spec.transpose(0, 1)
|
||||
|
||||
# 原始推理模式
|
||||
def infer(self, spk_audio_prompt, text, output_path,
|
||||
emo_audio_prompt=None, emo_alpha=1.0,
|
||||
@ -433,8 +469,12 @@ class IndexTTS2:
|
||||
self.cache_mel = None
|
||||
torch.cuda.empty_cache()
|
||||
audio,sr = self._load_and_cut_audio(spk_audio_prompt,15,verbose)
|
||||
audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio)
|
||||
audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio)
|
||||
# 使用 librosa 进行重采样
|
||||
audio_np = audio.squeeze().numpy()
|
||||
audio_22k_np = librosa.resample(audio_np, orig_sr=sr, target_sr=22050)
|
||||
audio_16k_np = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
|
||||
audio_22k = torch.from_numpy(audio_22k_np).unsqueeze(0)
|
||||
audio_16k = torch.from_numpy(audio_16k_np).unsqueeze(0)
|
||||
|
||||
inputs = self.extract_features(audio_16k, sampling_rate=16000, return_tensors="pt")
|
||||
input_features = inputs["input_features"]
|
||||
@ -446,10 +486,8 @@ class IndexTTS2:
|
||||
_, S_ref = self.semantic_codec.quantize(spk_cond_emb)
|
||||
ref_mel = self.mel_fn(audio_22k.to(spk_cond_emb.device).float())
|
||||
ref_target_lengths = torch.LongTensor([ref_mel.size(2)]).to(ref_mel.device)
|
||||
feat = torchaudio.compliance.kaldi.fbank(audio_16k.to(ref_mel.device),
|
||||
num_mel_bins=80,
|
||||
dither=0,
|
||||
sample_frequency=16000)
|
||||
# 使用 librosa 计算 fbank 特征(替代 torchaudio.compliance.kaldi.fbank)
|
||||
feat = self._compute_fbank_librosa(audio_16k.to(ref_mel.device), n_mels=80, sr=16000)
|
||||
feat = feat - feat.mean(dim=0, keepdim=True) # feat2另外一个滤波器能量组特征[922, 80]
|
||||
style = self.campplus_model(feat.unsqueeze(0)) # 参考音频的全局style2[1,192]
|
||||
|
||||
@ -694,7 +732,9 @@ class IndexTTS2:
|
||||
print(">> remove old wav file:", output_path)
|
||||
if os.path.dirname(output_path) != "":
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
torchaudio.save(output_path, wav.type(torch.int16), sampling_rate)
|
||||
# 使用 soundfile 保存音频(替代 torchaudio.save)
|
||||
wav_np = wav.squeeze().numpy().astype(np.int16)
|
||||
sf.write(output_path, wav_np, sampling_rate, subtype='PCM_16')
|
||||
print(">> wav file saved to:", output_path)
|
||||
if stream_return:
|
||||
return None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user