Commit eb363053 authored by hgarrereyn's avatar hgarrereyn
Browse files

add cpu/gpu and full smiles output

parent 670da606
......@@ -6,3 +6,7 @@ data/**
!data/README.md
.DS_Store
.store/
dist/**
build/**
......@@ -7,9 +7,7 @@ DeepFrag is a machine learning model for fragment-based lead optimization. In th
If you use DeepFrag in your research, please cite as:
```
Green, H., Koes, D. R., & Durrant, J. D. (2021). DeepFrag: a deep convolutional neural network for fragment-based lead optimization. Chemical Science.
```
```tex
@article{green2021deepfrag,
......@@ -64,11 +62,14 @@ To remove a fragment, you specify a second atom that is contained in the fragmen
By default, DeepFrag will print a list of fragment predictions to stdout similar to the [Browser App](https://durrantlab.pitt.edu/deepfrag/).
- `--out <out.csv>`: Save predictions in CSV format to `out.csv`.
- `--out <out.csv>`: Save predictions in CSV format to `out.csv`. Each line contains the fragment rank, score and SMILES string.
## Miscellaneous (optional)
- `--full`: Generate SMILES strings with the full ligand structure instead of just the fragment.
- `--cpu/--gpu`: DeepFrag will attempt to infer if a Cuda GPU is available and fallback to the CPU if it is not. You can set either the `--cpu` or `--gpu` flag to explicitly specify the target device.
- `--num_grids <num>`: Number of grid rotations to use. Using more will take longer but produce a more stable prediction. (Default: 4)
- `--top_k <k>`: Number of predictions to print in stdout. Use -1 to display all. (Default: 25)
# Reproduce Results
......
......@@ -207,13 +207,13 @@ def get_structures(args):
parent_coords = util.get_coords(lig)
parent_types = np.array(util.get_types(lig)).reshape((-1,1))
return (rec_coords, rec_types, parent_coords, parent_types, conn)
return (rec_coords, rec_types, parent_coords, parent_types, conn, lig)
def get_model(args):
def get_model(args, device):
"""Load a pre-trained DeepFrag model."""
print('[*] Loading model ... ', end='')
model = LeadoptModel.load(str(get_model_path() / 'final_model'), device='cpu')
model = LeadoptModel.load(str(get_model_path() / 'final_model'), device=('cuda' if device == 'gpu' else device))
print('done.')
return model
......@@ -233,7 +233,26 @@ def get_fingerprints(args):
return (f_smiles, f_fingerprints)
def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, parent_types, conn):
def get_target_device(args) -> str:
"""Infer the target device or use the argument overrides."""
device = 'gpu' if torch.cuda.device_count() > 0 else 'cpu'
if args.cpu:
if device == 'gpu':
print('[*] Warning: GPU is available but running on CPU due to --cpu flag')
device = 'cpu'
elif args.gpu:
if device == 'cpu':
print('[*] Error: No CUDA-enabled GPU was found. Exiting due to --gpu flag. You can run on the CPU instead with the --cpu flag.')
exit(-1)
device = 'gpu'
print('[*] Running on device: %s' % device)
return device
def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, parent_types, conn, device):
start = time.time()
print('[*] Generating grids ... ', end='', flush=True)
......@@ -248,7 +267,7 @@ def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, paren
point_radius=model_args['point_radius'],
point_type=model_args['point_type'],
acc_type=model_args['acc_type'],
cpu=True
cpu=(device == 'cpu')
)
print('done.')
end = time.time()
......@@ -276,44 +295,80 @@ def get_predictions(model, batch, f_smiles, f_fingerprints):
dist = list(dist.numpy())
scores = list(zip(f_smiles, dist))
scores = sorted(scores, key=lambda x:x[1], reverse=True)
scores = [(a.decode('ascii'), b) for a,b in scores]
return scores
def gen_output(args, scores):
if args.top_k != -1:
scores = scores[:args.top_k]
if args.out is None:
# Write results to stdout.
print('%4s %8s %s' % ('#', 'Score', 'Fragment'))
print('%4s %8s %s' % ('#', 'Score', 'SMILES'))
for i in range(len(scores)):
smi, score = scores[i]
print('%4d %8f %s' % (i+1, score, smi.decode('ascii')))
print('%4d %8f %s' % (i+1, score, smi))
else:
# Write csv output.
csv = 'Rank,Fragment SMILES,Score\n'
csv = 'Rank,SMILES,Score\n'
for i in range(len(scores)):
smi, score = scores[i]
csv += '%d,%s,%f\n' % (
i+1, smi.decode('ascii'), score
i+1, smi, score
)
open(args.out, 'w').write(csv)
print('[*] Wrote output to %s' % args.out)
def fuse(lig, frag):
merged = Chem.RWMol(Chem.CombineMols(lig, frag))
conn_atoms = [a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0]
neighbors = [merged.GetAtomWithIdx(x).GetNeighbors()[0].GetIdx() for x in conn_atoms]
bond = merged.AddBond(neighbors[0], neighbors[1], Chem.rdchem.BondType.SINGLE)
merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0])
merged.RemoveAtom([a.GetIdx() for a in merged.GetAtoms() if a.GetAtomicNum() == 0][0])
Chem.SanitizeMol(merged)
return merged
def fuse_fragments(lig, conn, scores):
new_sc = []
for smi, score in scores:
try:
frag = Chem.MolFromSmiles(smi)
fused = fuse(Chem.Mol(lig), frag)
new_sc.append((Chem.MolToSmiles(fused, False), score))
except:
print('[*] Error: couldn\'t process mol.')
new_sc.append(('<err>', score))
return new_sc
def run(args):
model = get_model(args)
device = get_target_device(args)
model = get_model(args, device)
f_smiles, f_fingerprints = get_fingerprints(args)
rec_coords, rec_types, parent_coords, parent_types, conn = get_structures(args)
rec_coords, rec_types, parent_coords, parent_types, conn, lig = get_structures(args)
batch = generate_grids(args, model._args, rec_coords, rec_types,
parent_coords, parent_types, conn)
parent_coords, parent_types, conn, device)
scores = get_predictions(model, batch, f_smiles, f_fingerprints)
if args.top_k != -1:
scores = scores[:args.top_k]
if args.full:
scores = fuse_fragments(lig, conn, scores)
gen_output(args, scores)
......@@ -341,16 +396,26 @@ def main():
parser.add_argument('--rname', type=str, help='Removal point atom name.')
# Misc
parser.add_argument('--num_grids', type=int, default=4, help='Number of grid rotations.')
parser.add_argument('-k', '--top_k', type=int, default=25, help='Number of results to show. Set to -1 to show all.')
parser.add_argument('--out', type=str, help='Path to output CSV file.')
parser.add_argument('--full', action='store_true', default=False,
help='Print the full (fused) ligand structure.')
parser.add_argument('--num_grids', type=int, default=4,
help='Number of grid rotations.')
parser.add_argument('--top_k', type=int, default=25,
help='Number of results to show. Set to -1 to show all.')
parser.add_argument('--out', type=str,
help='Path to output CSV file.')
parser.add_argument('--cpu', action='store_true', default=False,
help='Use the CPU for grid generation and predictions.')
parser.add_argument('--gpu', action='store_true', default=False,
help='Use a (CUDA-capable) GPU for grid generation and predictions.')
args = parser.parse_args()
groupings = [
([('receptor', 'ligand'), ('pdb', 'resnum')], True),
([('cx', 'cy', 'cz'), ('cname',)], True),
([('rx', 'ry', 'rz'), ('rname',)], False)
([('rx', 'ry', 'rz'), ('rname',)], False),
([('cpu',), ('gpu',)], False)
]
for grp, req in groupings:
......@@ -358,7 +423,7 @@ def main():
complete = 0
for subset in grp:
res = [getattr(args, name) is not None for name in subset]
res = [not (getattr(args, name) in [None, False]) for name in subset]
partial.append(any(res) and not all(res))
complete += int(all(res))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment