diff --git a/logger.py b/logger.py index 9b999ad..e6422c5 100644 --- a/logger.py +++ b/logger.py @@ -31,18 +31,18 @@ class Tacotron2Logger(SummaryWriter): self.add_image( "alignment", plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), - iteration) + iteration, dataformats='HWC') self.add_image( "mel_target", plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), - iteration) + iteration, dataformats='HWC') self.add_image( "mel_predicted", plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), - iteration) + iteration, dataformats='HWC') self.add_image( "gate", plot_gate_outputs_to_numpy( gate_targets[idx].data.cpu().numpy(), torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()), - iteration) + iteration, dataformats='HWC')