# Original (problematic) version, kept as-is for comparison with the corrected
# lines further down.
# NOTE(review): this line passes `args.gpus`, while `device_ids` on the next
# line uses `args.gpu` -- presumably a typo for `args.gpu`; confirm.
model = AE(input_shape=784).cuda(args.gpus)
# NOTE(review): `model_sync` is not bound anywhere in this snippet; the module
# created above is named `model`, so the wrapper is given a different name than
# the one just assigned.
model = torch.nn.parallel.DistributedDataParallel( model_sync, device_ids=[args.gpu], find_unused_parameters=True )
should be changed to:
# Corrected version: instantiate AE and move it to the CUDA device `args.gpu`.
model = AE(input_shape=784).cuda(args.gpu)
# Wrap that same `model` object in DistributedDataParallel, listing the same
# device index in `device_ids` that the module was moved to above.
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.gpu],
    find_unused_parameters=True,
)