You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
1.7 KiB

  1. import matplotlib
  2. matplotlib.use("Agg")
  3. import matplotlib.pylab as plt
  4. import numpy as np
  5. def save_figure_to_numpy(fig):
  6. # save it to a numpy array.
  7. data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
  8. data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
  9. return data
  10. def plot_alignment_to_numpy(alignment, info=None):
  11. fig, ax = plt.subplots(figsize=(6, 4))
  12. im = ax.imshow(alignment, aspect='auto', origin='lower',
  13. interpolation='none')
  14. fig.colorbar(im, ax=ax)
  15. xlabel = 'Decoder timestep'
  16. if info is not None:
  17. xlabel += '\n\n' + info
  18. plt.xlabel(xlabel)
  19. plt.ylabel('Encoder timestep')
  20. plt.tight_layout()
  21. fig.canvas.draw()
  22. data = save_figure_to_numpy(fig)
  23. plt.close()
  24. return data
  25. def plot_spectrogram_to_numpy(spectrogram):
  26. fig, ax = plt.subplots(figsize=(12, 3))
  27. im = ax.imshow(spectrogram, aspect="auto", origin="lower",
  28. interpolation='none')
  29. plt.colorbar(im, ax=ax)
  30. plt.xlabel("Frames")
  31. plt.ylabel("Channels")
  32. plt.tight_layout()
  33. fig.canvas.draw()
  34. data = save_figure_to_numpy(fig)
  35. plt.close()
  36. return data
  37. def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
  38. fig, ax = plt.subplots(figsize=(12, 3))
  39. ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5,
  40. color='green', marker='+', s=1, label='target')
  41. ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5,
  42. color='red', marker='.', s=1, label='predicted')
  43. plt.xlabel("Frames (Green target, Red predicted)")
  44. plt.ylabel("Gate State")
  45. plt.tight_layout()
  46. fig.canvas.draw()
  47. data = save_figure_to_numpy(fig)
  48. plt.close()
  49. return data