Fix autocast device type for compatibility
This commit is contained in:
parent
bfd787eaa6
commit
dd2b7dd820
@ -336,7 +336,7 @@ class IndexTTS:
|
||||
self._set_gr_progress(0.2, "gpt inference speech...")
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
temp_codes = self.gpt.inference_speech(batch_auto_conditioning, batch_text_tokens,
|
||||
cond_mel_lengths=batch_cond_mel_lengths,
|
||||
# text_lengths=text_len,
|
||||
@ -369,7 +369,7 @@ class IndexTTS:
|
||||
all_idxs.append(batch_sentences[i]["idx"])
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
latent = \
|
||||
self.gpt(auto_conditioning, text_tokens,
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
||||
@ -395,7 +395,7 @@ class IndexTTS:
|
||||
tqdm_progress.update(len(items))
|
||||
latent = torch.cat(items, dim=1)
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
with torch.amp.autocast(latent.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
m_start_time = time.perf_counter()
|
||||
wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2))
|
||||
bigvgan_time += time.perf_counter() - m_start_time
|
||||
@ -516,7 +516,7 @@ class IndexTTS:
|
||||
|
||||
m_start_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
codes = self.gpt.inference_speech(auto_conditioning, text_tokens,
|
||||
cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]],
|
||||
device=text_tokens.device),
|
||||
@ -548,7 +548,7 @@ class IndexTTS:
|
||||
|
||||
m_start_time = time.perf_counter()
|
||||
# latent, text_lens_out, code_lens_out = \
|
||||
with torch.amp.autocast(self.device, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype):
|
||||
latent = \
|
||||
self.gpt(auto_conditioning, text_tokens,
|
||||
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user