From 7d8644fb3150f71100aa15b8c177f0a0bc204ae3 Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 1 Aug 2023 22:14:26 +0000 Subject: [PATCH] add flag for device --- examples/inverse_folding/sample_sequences.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/inverse_folding/sample_sequences.py b/examples/inverse_folding/sample_sequences.py index 3a126930..531a4017 100644 --- a/examples/inverse_folding/sample_sequences.py +++ b/examples/inverse_folding/sample_sequences.py @@ -31,7 +31,7 @@ def sample_seq_singlechain(model, alphabet, args): with open(args.outpath, 'w') as f: for i in range(args.num_samples): print(f'\nSampling.. ({i+1} of {args.num_samples})') - sampled_seq = model.sample(coords, temperature=args.temperature, device=torch.device('cuda')) + sampled_seq = model.sample(coords, temperature=args.temperature, device=torch.device(args.device)) print('Sampled sequence:') print(sampled_seq) f.write(f'>sampled_seq_{i+1}\n') @@ -97,6 +97,11 @@ def main(): help='number of sequences to sample', default=1, ) + parser.add_argument( + '--device', type=str, + help='torch device to use', + default='cuda', + ) parser.set_defaults(multichain_backbone=False) parser.add_argument( '--multichain-backbone', action='store_true',