diff --git a/indextts/infer.py b/indextts/infer.py index b2233b4..b71da6d 100644 --- a/indextts/infer.py +++ b/indextts/infer.py @@ -336,7 +336,7 @@ class IndexTTS: self._set_gr_progress(0.2, "gpt inference speech...") m_start_time = time.perf_counter() with torch.no_grad(): - with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens, cond_mel_lengths=batch_cond_mel_lengths, # text_lengths=text_len, @@ -369,7 +369,7 @@ class IndexTTS: all_idxs.append(batch_sentences[i]["idx"]) m_start_time = time.perf_counter() with torch.no_grad(): - with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): latent = \ self.gpt(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, @@ -395,7 +395,7 @@ class IndexTTS: tqdm_progress.update(len(items)) latent = torch.cat(items, dim=1) with torch.no_grad(): - with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + with torch.amp.autocast(latent.device.type, enabled=self.dtype is not None, dtype=self.dtype): m_start_time = time.perf_counter() wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2)) bigvgan_time += time.perf_counter() - m_start_time @@ -516,7 +516,7 @@ class IndexTTS: m_start_time = time.perf_counter() with torch.no_grad(): - with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): codes = self.gpt.inference_speech(auto_conditioning, text_tokens, cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], device=text_tokens.device), @@ -548,7 +548,7 @@ class IndexTTS: m_start_time = time.perf_counter() # latent, text_lens_out, code_lens_out = \ - with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype): + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): latent = \ self.gpt(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,