

import os
import subprocess
import sys
import time

import torch

# Arguments passed after this launcher's own name are forwarded to each worker.
argslist = list(sys.argv)[1:]
num_gpus = torch.cuda.device_count()
argslist.append('--n_gpus={}'.format(num_gpus))
workers = []
job_id = time.strftime("%Y_%m_%d-%H%M%S")
argslist.append("--group_name=group_{}".format(job_id))

# Per-worker log files go under logs/; create the directory if it is missing.
os.makedirs("logs", exist_ok=True)

# Spawn one worker process per visible GPU, each with its own --rank flag.
for i in range(num_gpus):
    argslist.append('--rank={}'.format(i))
    # Rank 0 writes to the console; every other rank logs to its own file.
    stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i),
                                      "w")
    print(argslist)
    p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout)
    workers.append(p)
    # Drop the --rank flag so the next iteration appends a fresh one.
    argslist = argslist[:-1]

# Block until every worker process has exited.
for p in workers:
    p.wait()
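Usage sketch (the target script name and its extra flag below are hypothetical, not part of this file): running

    python multiproc.py train.py --output_directory=out

respawns train.py once per visible GPU via sys.executable, appending --n_gpus, --group_name, and a per-worker --rank to the forwarded arguments, so the target script must accept those three flags.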