Fix autocast device type for compatibility

This commit is contained in:
Yrom 2025-04-24 11:00:49 +08:00
parent bfd787eaa6
commit dd2b7dd820
No known key found for this signature in database

View File

@@ -336,7 +336,7 @@ class IndexTTS:
self._set_gr_progress(0.2, "gpt inference speech...")
m_start_time = time.perf_counter()
with torch.no_grad():
-with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
+with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
cond_mel_lengths=batch_cond_mel_lengths,
# text_lengths=text_len,
@@ -369,7 +369,7 @@ class IndexTTS:
all_idxs.append(batch_sentences[i]["idx"])
m_start_time = time.perf_counter()
with torch.no_grad():
-with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
+with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
latent = \
self.gpt(auto_conditioning, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
@@ -395,7 +395,7 @@ class IndexTTS:
tqdm_progress.update(len(items))
latent = torch.cat(items, dim=1)
with torch.no_grad():
-with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
+with torch.amp.autocast(latent.device.type, enabled=self.dtype is not None, dtype=self.dtype):
m_start_time = time.perf_counter()
wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
bigvgan_time += time.perf_counter() - m_start_time
@@ -516,7 +516,7 @@ class IndexTTS:
m_start_time = time.perf_counter()
with torch.no_grad():
-with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
+with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
device=text_tokens.device),
@@ -548,7 +548,7 @@ class IndexTTS:
m_start_time = time.perf_counter()
# latent, text_lens_out, code_lens_out = \
-with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
+with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
latent = \
self.gpt(auto_conditioning, text_tokens,
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,