You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

76 lines
2.2 KiB

  1. """ from https://github.com/keithito/tacotron """
  2. import re
  3. from text import cleaners
  4. from text.symbols import symbols
  5. # Mappings from symbol to numeric ID and vice versa:
  6. _symbol_to_id = {s: i for i, s in enumerate(symbols)}
  7. _id_to_symbol = {i: s for i, s in enumerate(symbols)}
  8. # Regular expression matching text enclosed in curly braces:
  9. _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
  10. def text_to_sequence(text, cleaner_names):
  11. '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
  12. The text can optionally have ARPAbet sequences enclosed in curly braces embedded
  13. in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
  14. Args:
  15. text: string to convert to a sequence
  16. cleaner_names: names of the cleaner functions to run the text through
  17. Returns:
  18. List of integers corresponding to the symbols in the text
  19. '''
  20. sequence = []
  21. # Check for curly braces and treat their contents as ARPAbet:
  22. while len(text):
  23. m = _curly_re.match(text)
  24. if not m:
  25. sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
  26. break
  27. sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
  28. sequence += _arpabet_to_sequence(m.group(2))
  29. text = m.group(3)
  30. # Append EOS token
  31. sequence.append(_symbol_to_id['~'])
  32. return sequence
  33. def sequence_to_text(sequence):
  34. '''Converts a sequence of IDs back to a string'''
  35. result = ''
  36. for symbol_id in sequence:
  37. if symbol_id in _id_to_symbol:
  38. s = _id_to_symbol[symbol_id]
  39. # Enclose ARPAbet back in curly braces:
  40. if len(s) > 1 and s[0] == '@':
  41. s = '{%s}' % s[1:]
  42. result += s
  43. return result.replace('}{', ' ')
  44. def _clean_text(text, cleaner_names):
  45. for name in cleaner_names:
  46. cleaner = getattr(cleaners, name)
  47. if not cleaner:
  48. raise Exception('Unknown cleaner: %s' % name)
  49. text = cleaner(text)
  50. return text
  51. def _symbols_to_sequence(symbols):
  52. return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
  53. def _arpabet_to_sequence(text):
  54. return _symbols_to_sequence(['@' + s for s in text.split()])
  55. def _should_keep_symbol(s):
  56. return s in _symbol_to_id and s is not '_' and s is not '~'