# DeepFrag
DeepFrag is a machine learning model for fragment-based lead optimization. This repository contains code to train the model and to run inference with a pre-trained model.
# Examples
See the example Google Colab notebook for an interactive example of how to use a pre-trained DeepFrag model to generate predictions.
# Citation
If you use DeepFrag in your research, please cite as:
Green, H., Koes, D. R., & Durrant, J. D. (2021). DeepFrag: a deep convolutional neural network for fragment-based lead optimization. Chemical Science.
title={DeepFrag: a deep convolutional neural network for fragment-based lead optimization},
author={Green, Harrison and Koes, David Ryan and Durrant, Jacob D},
journal={Chemical Science},
publisher={Royal Society of Chemistry}
# Usage
There are three ways to use DeepFrag:
1. **DeepFrag Browser App**: We have released a free, open-source browser app for DeepFrag that requires no setup and does not transmit any structures to a remote server.
- View the online version at [](
- See the code at [](
2. **DeepFrag CLI**: In this repository we have included a `` script that can perform common prediction tasks using the API.
- See the `DeepFrag CLI` section below
3. **DeepFrag API**: For custom tasks or fine-grained control over predictions, you can invoke the DeepFrag API directly and interface with the raw data structures and the PyTorch model. We have created an example Google Colab (Jupyter notebook) that demonstrates how to perform manual predictions.
- See the interactive [Colab](
# DeepFrag CLI
The DeepFrag CLI is invoked by running `python3` in this repository. The CLI requires a pre-trained model and the fragment library to run. You will be prompted to download both when you first run the CLI and these will be saved in the `./.store` directory.
## Structure (specify exactly one)
Specify the input structures either as separate receptor and ligand PDB files, or as a PDB ID together with the residue number of the ligand.
- `--receptor <rec.pdb> --ligand <lig.pdb>`
- `--pdb <pdbid> --resnum <resnum>`
## Connection Point (specify exactly one)
DeepFrag will predict new fragments that connect to the _connection point_ via a single bond. You must specify the connection point atom using one of the following:
- `--cname <name>`: Specify the connection point by atom name (e.g. `C3`, `N5`, `O2`, ...).
- `--cx <x> --cy <y> --cz <z>`: Specify the connection point by atomic coordinate. DeepFrag will find the closest atom to this point.
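
The coordinate option can be pictured as a nearest-neighbour lookup over the ligand atoms. The sketch below is illustrative only: `closest_atom` and the sample atom list are hypothetical and are not part of the DeepFrag code base. It returns the atom with the smallest Euclidean distance to the query point.

```python
import math
from typing import List, Tuple

# (atom name, (x, y, z)); a hypothetical stand-in for a parsed ligand.
Atom = Tuple[str, Tuple[float, float, float]]


def closest_atom(atoms: List[Atom], point: Tuple[float, float, float]) -> Atom:
    """Return the atom nearest to `point` (hypothetical helper)."""
    if not atoms:
        raise ValueError('atom list is empty')
    # math.dist computes the Euclidean distance between two points.
    return min(atoms, key=lambda atom: math.dist(atom[1], point))


# Made-up coordinates for illustration.
ligand_atoms = [
    ('C3', (1.0, 0.0, 0.0)),
    ('N5', (4.0, 2.0, 1.0)),
    ('O2', (0.0, 3.0, 5.0)),
]
nearest = closest_atom(ligand_atoms, (3.6, 2.2, 0.9))
print(nearest[0])
```

Because the lookup always returns some atom, a query point that is far from the ligand still resolves to its nearest atom, so it is worth checking the coordinates you pass.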
## Fragment Removal (optional) (specify exactly one)
If you are using DeepFrag for fragment _replacement_, you must first remove the original fragment from the ligand structure. You can either do this by hand, e.g. editing the PDB, or DeepFrag can do this for you by specifying _which_ fragment should be removed.
_Note: predicting fragments in place of hydrogen atoms (e.g. protons) does not require any fragment removal since hydrogen atoms are ignored by the model._
To remove a fragment, you specify a second atom that is contained in the fragment. Like the connection point, you can either use the atom name or the atom coordinate.
- `--rname <name>`: Specify the removal point by atom name (e.g. `C3`, `N5`, `O2`, ...).
- `--rx <x> --ry <y> --rz <z>`: Specify the removal point by atomic coordinate. DeepFrag will find the closest atom to this point.
## Output (optional)
By default, DeepFrag will print a list of fragment predictions to stdout, similar to the Browser App.
- `--out <out.csv>`: Save predictions in CSV format to `out.csv`.
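
For downstream processing, the CSV can be read back with Python's standard `csv` module. The sketch below assumes the header `Rank,Fragment SMILES,Score` that the CLI code in this repository writes; the function name `load_predictions` and the sample rows are hypothetical.

```python
import csv
import io
from typing import List, Tuple


def load_predictions(handle) -> List[Tuple[int, str, float]]:
    """Parse DeepFrag-style CSV rows into (rank, smiles, score) tuples."""
    reader = csv.DictReader(handle)
    return [
        (int(row['Rank']), row['Fragment SMILES'], float(row['Score']))
        for row in reader
    ]


# Made-up sample standing in for the contents of out.csv.
sample = io.StringIO(
    'Rank,Fragment SMILES,Score\n'
    '1,*C(=O)O,0.912345\n'
    '2,*c1ccccc1,0.854321\n'
)
predictions = load_predictions(sample)
best_rank, best_smiles, best_score = predictions[0]
print(best_smiles, best_score)
```

With a real file, pass an open handle instead of the in-memory sample, e.g. `with open('out.csv', newline='') as f: load_predictions(f)`.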
## Miscellaneous (optional)
- `--cpu/--gpu`: DeepFrag checks whether a CUDA GPU is available and falls back to the CPU if it is not. Set either `--cpu` or `--gpu` to choose the target device explicitly.
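
One way such a fallback could work is sketched below. `pick_device`, `cuda_available` and their arguments are hypothetical, and the logic is an assumption rather than the repository's implementation. `torch.cuda.is_available()` is the standard PyTorch check; the sketch treats a missing PyTorch install as "no GPU" and raises an error if a GPU is requested but not present.

```python
from typing import Optional


def cuda_available() -> bool:
    """Report whether PyTorch can see a CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


def pick_device(force_cpu: bool = False, force_gpu: bool = False,
                has_cuda: Optional[bool] = None) -> str:
    """Return 'cuda' or 'cpu' (hypothetical helper mirroring --cpu/--gpu)."""
    if force_cpu and force_gpu:
        raise ValueError('choose at most one of force_cpu and force_gpu')
    if force_cpu:
        return 'cpu'
    # Only probe the hardware when the caller did not supply the answer.
    if has_cuda is None:
        has_cuda = cuda_available()
    if force_gpu:
        if not has_cuda:
            raise RuntimeError('GPU requested but no CUDA device is available')
        return 'cuda'
    # No explicit choice: prefer the GPU and fall back to the CPU.
    return 'cuda' if has_cuda else 'cpu'


print(pick_device(force_cpu=True))
```

The `has_cuda` parameter exists only so that the decision logic can be exercised without a GPU.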
# Reproduce Results
You can use the DeepFrag CLI to reproduce the highlighted results from the main manuscript:
## 1. Fragment replacement
To replace fragments, specify the connection point (`cname` or `cx/cy/cz`) and specify a second atom that is contained in the fragment (`rname` or `rx/ry/rz`).

    # Fig. 3: (2XP9) H. sapiens peptidyl-prolyl cis–trans isomerase NIMA-interacting 1 (HsPin1p)
    # Carboxylate A
    $ python3 --pdb 2xp9 --resnum 1165 --cname C10 --rname C12
    # Phenyl B
    $ python3 --pdb 2xp9 --resnum 1165 --cname C1 --rname C2
    # Phenyl C
    $ python3 --pdb 2xp9 --resnum 1165 --cname C18 --rname C19
    # Fig. 4A: (6QZ8) Protein myeloid cell leukemia1 (Mcl-1)
    # Carboxylate group interacting with R263
    $ python3 --pdb 6qz8 --resnum 401 --cname C12 --rname C14
    # Ethyl group
    $ python3 --pdb 6qz8 --resnum 401 --cname C6 --rname C10
    # Methyl group
    $ python3 --pdb 6qz8 --resnum 401 --cname C25 --rname C30
    # Chlorine atom
    $ python3 --pdb 6qz8 --resnum 401 --cname C28 --rname CL
    # Fig. 4B: (1X38) Family GH3 b-D-glucan glucohydrolase (barley)
    # Hydroxyl group interacting with R158 and D285
    $ python3 --pdb 1x38 --resnum 1001 --cname C2B --rname O2B
    # Phenyl group interacting with W286 and W434
    $ python3 --pdb 1x38 --resnum 1001 --cname C7B --rname C1
    # Fig. 4C: (4FOW) NanB sialidase (Streptococcus pneumoniae)
    # Amino group
    $ python3 --pdb 4fow --resnum 701 --cname CAE --rname NAA

## 2. Fragment addition
For fragment addition, you only need to specify the atom connection point (`cname` or `cx/cy/cz`). In this case, DeepFrag will implicitly replace a valent hydrogen.

    # Fig. 5: Ligands targeting the SARS-CoV-2 main protease (MPro)
    # 5A: (5RGH) Extension on Z1619978933
    $ python3 --pdb 5rgh --resnum 404 --cname C09
    # 5B: (5R81) Extension on Z1367324110
    $ python3 --pdb 5r81 --resnum 1001 --cname C07

# Overview
import argparse
import functools
import os
import pathlib
import shutil
import time
from typing import Tuple
import zipfile
import requests
from tqdm.auto import tqdm
import h5py
import numpy as np
import rdkit.Chem.AllChem as Chem
import torch
import prody
from leadopt.model_conf import LeadoptModel, REC_TYPER, LIG_TYPER, DIST_FN
from leadopt import util, grid_util
USER_DIR = './.store'
PDB_CACHE = 'pdb_cache'
def download_remote(url, path, compression=None):
    """Download url to path, optionally unpacking an archive in place."""
    r = requests.get(url, stream=True, allow_redirects=True)
    if r.status_code != 200:
        print(f'Can\'t access {url}')
        raise SystemExit(1)

    file_size = int(r.headers.get('Content-Length', 0))
    # Decode compressed transfer encodings while streaming the raw response.
    r.raw.read = functools.partial(r.raw.read, decode_content=True)

    with tqdm.wrapattr(r.raw, 'read', total=file_size, desc='Downloading') as r_raw:
        with path.open('wb') as f:
            shutil.copyfileobj(r_raw, f)

    if compression is not None:
        # Move the archive aside, then unpack it to the requested path.
        shutil.move(str(path), str(path) + '.tmp')
        shutil.unpack_archive(str(path) + '.tmp', str(path), format=compression)

def get_deepfrag_user_dir() -> pathlib.Path:
    user_dir = pathlib.Path(os.path.realpath(__file__)).parent / USER_DIR
    os.makedirs(str(user_dir), exist_ok=True)
    return user_dir

def get_model_path():
    return get_deepfrag_user_dir() / 'model'

def get_fingerprints_path():
    return get_deepfrag_user_dir() / 'fingerprints.h5'

def ensure_cli_data():
    """Offer to download the model and fingerprint library if they are missing."""
    model_path = get_model_path()
    fingerprints_path = get_fingerprints_path()

    if not os.path.exists(str(model_path)):
        r = input('Pre-trained DeepFrag model not found, download it now? (5.8 MB) [Y/n]: ')
        if r.lower() == 'n':
            # The CLI cannot run without the model.
            raise SystemExit(1)
        print(f'Saving to {model_path}...')
        download_remote(MODEL_DOWNLOAD, model_path, compression='zip')

    if not os.path.exists(str(fingerprints_path)):
        r = input('Fingerprint library not found, download it now? (11 MB) [Y/n]: ')
        if r.lower() == 'n':
            # The CLI cannot run without the fingerprint library.
            raise SystemExit(1)
        print(f'Saving to {fingerprints_path}...')
        download_remote(FINGERPRINTS_DOWNLOAD, fingerprints_path, compression=None)

def download_pdb(pdb_id, path):
    download_remote(RCSB_DOWNLOAD % pdb_id, path, compression=None)

def load_pdb(pdb_id, resnum):
    """Fetch a PDB entry (cached on disk) and split it into receptor and ligand files."""
    pdb_id = pdb_id.upper()
    assert all([x in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' for x in pdb_id])

    # Check pdb cache
    pdb_dir = get_deepfrag_user_dir() / PDB_CACHE / pdb_id
    complex_path = pdb_dir / 'complex.pdb'
    rec_path = pdb_dir / 'receptor.pdb'
    lig_path = pdb_dir / 'ligand.pdb'

    os.makedirs(str(pdb_dir), exist_ok=True)
    if not os.path.exists(complex_path):
        download_pdb(pdb_id, complex_path)

    m = prody.parsePDB(str(complex_path))
    rec = m.select('not (nucleic or hetatm) and not water')
    lig = m.select('resnum %d' % resnum)

    if lig is None:
        print('[!] Error could not find ligand with resnum: %d' % resnum)
        raise SystemExit(1)

    prody.writePDB(str(rec_path), rec)
    prody.writePDB(str(lig_path), lig)
    return (str(rec_path), str(lig_path))

def get_structure_paths(args) -> Tuple[str, str]:
    """Get structure paths specified by the command line args.

    Returns (rec_path, lig_path)
    """
    if args.receptor is not None and args.ligand is not None:
        return (args.receptor, args.ligand)
    elif args.pdb is not None and args.resnum is not None:
        return load_pdb(args.pdb, args.resnum)
    else:
        raise NotImplementedError()

def preprocess_ligand(lig, conn, rvec):
    """
    Remove the fragment from lig connected via the atom at conn and containing
    the atom at rvec.
    """
    # Generate all fragments.
    frags = util.generate_fragments(lig)

    for parent, frag in frags:
        # Locate the dummy atom (atomic number 0) of this fragment.
        cidx = [a for a in frag.GetAtoms() if a.GetAtomicNum() == 0][0].GetIdx()
        vec = frag.GetConformer().GetAtomPosition(cidx)
        c_vec = np.array([vec.x, vec.y, vec.z])

        # Check connection point.
        if np.linalg.norm(c_vec - conn) < 1e-3:
            # Check removal point.
            frag_pos = frag.GetConformer().GetPositions()
            min_dist = np.min(np.sum((frag_pos - rvec) ** 2, axis=1))

            if min_dist < 1e-3:
                # Found fragment.
                print('[*] Removing fragment with %d atoms (%s)' % (
                    frag_pos.shape[0] - 1, Chem.MolToSmiles(frag, False)))
                return parent

    print('[!] Could not find a suitable fragment to remove.')
    raise SystemExit(1)

def lookup_atom_name(lig_path, name):
    """Try to look up an atom by name. Returns the coordinate of the atom if
    found."""
    p = prody.parsePDB(lig_path)
    p = p.select('name %s' % name)
    if p is None:
        print('[!] Error: no atom with name "%s" in ligand' % name)
        raise SystemExit(1)
    elif len(p) > 1:
        print('[!] Error: multiple atoms with name "%s" in ligand' % name)
        raise SystemExit(1)
    return p.getCoords()[0]

def get_structures(args):
    rec_path, lig_path = get_structure_paths(args)

    print(f'[*] Loading receptor: {rec_path} ... ', end='')
    rec_coords, rec_types = util.load_receptor_ob(rec_path)

    print(f'[*] Loading ligand: {lig_path} ... ', end='')
    lig = Chem.MolFromPDBFile(lig_path)

    # Connection point: explicit coordinates take precedence over an atom name.
    conn = None
    if args.cx is not None and args.cy is not None and args.cz is not None:
        conn = np.array([float(args.cx), float(args.cy), float(args.cz)])
    elif args.cname is not None:
        conn = lookup_atom_name(lig_path, args.cname)
    else:
        raise NotImplementedError()

    # Removal point (optional).
    rvec = None
    if args.rx is not None and args.ry is not None and args.rz is not None:
        rvec = np.array([float(args.rx), float(args.ry), float(args.rz)])
    elif args.rname is not None:
        rvec = lookup_atom_name(lig_path, args.rname)

    if rvec is not None:
        lig = preprocess_ligand(lig, conn, rvec)

    parent_coords = util.get_coords(lig)
    parent_types = np.array(util.get_types(lig)).reshape((-1,1))
    return (rec_coords, rec_types, parent_coords, parent_types, conn)

def get_model(args):
    """Load a pre-trained DeepFrag model."""
    print('[*] Loading model ... ', end='')
    model = LeadoptModel.load(str(get_model_path() / 'final_model'), device='cpu')
    return model

def get_fingerprints(args):
    """Load the fingerprint library.

    Returns (smiles, fingerprints).
    """
    f_smiles = None
    f_fingerprints = None
    print('[*] Loading fingerprint library ... ', end='')
    with h5py.File(str(get_fingerprints_path()), 'r') as f:
        f_smiles = f['smiles'][()]
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        f_fingerprints = f['fingerprints'][()].astype(float)
    return (f_smiles, f_fingerprints)

def generate_grids(args, model_args, rec_coords, rec_types, parent_coords, parent_types, conn):
    start = time.time()
    print('[*] Generating grids ... ', end='', flush=True)
    batch = grid_util.get_raw_batch(
        rec_coords, rec_types, parent_coords, parent_types,
        # ... (further arguments are truncated in this listing)
    )
    end = time.time()
    print(f'[*] Generated grids in {end-start:.3f} seconds.')
    return batch

def get_predictions(model, batch, f_smiles, f_fingerprints):
    start = time.time()
    pred = model.predict(torch.tensor(batch).float()).cpu().numpy()
    end = time.time()
    print(f'[*] Generated prediction in {end-start} seconds.')

    # Average the predicted fingerprints over the batch of grids.
    avg_fp = np.mean(pred, axis=0)
    dist_fn = DIST_FN[model._args['dist_fn']]

    # The distance functions are implemented in pytorch so we need to convert our
    # numpy arrays to a torch Tensor.
    # Assumption: dist_fn takes (prediction, library) and broadcasts over library rows.
    dist = 1 - dist_fn(
        torch.Tensor(avg_fp).unsqueeze(0),
        torch.Tensor(f_fingerprints))

    # Pair smiles strings and distances.
    dist = list(dist.numpy())
    scores = list(zip(f_smiles, dist))
    scores = sorted(scores, key=lambda x:x[1], reverse=True)
    return scores

def gen_output(args, scores):
    if args.top_k != -1:
        scores = scores[:args.top_k]

    if args.out is None:
        # Write results to stdout.
        print('%4s %8s %s' % ('#', 'Score', 'Fragment'))
        for i in range(len(scores)):
            smi, score = scores[i]
            print('%4d %8f %s' % (i+1, score, smi.decode('ascii')))
    else:
        # Write csv output.
        csv = 'Rank,Fragment SMILES,Score\n'
        for i in range(len(scores)):
            smi, score = scores[i]
            csv += '%d,%s,%f\n' % (
                i+1, smi.decode('ascii'), score
            )
        # Use a context manager so the file is closed after writing.
        with open(args.out, 'w') as f:
            f.write(csv)
        print('[*] Wrote output to %s' % args.out)

def run(args):
    model = get_model(args)
    f_smiles, f_fingerprints = get_fingerprints(args)
    rec_coords, rec_types, parent_coords, parent_types, conn = get_structures(args)
    batch = generate_grids(args, model._args, rec_coords, rec_types,
                           parent_coords, parent_types, conn)
    scores = get_predictions(model, batch, f_smiles, f_fingerprints)
    gen_output(args, scores)

def main():
    parser = argparse.ArgumentParser()

    # Structure
    parser.add_argument('--receptor', help='Path to receptor structure.')
    parser.add_argument('--ligand', help='Path to ligand structure.')
    parser.add_argument('--pdb', help='PDB ID to download.')
    parser.add_argument('--resnum', type=int, help='Residue number of ligand.')

    # Connection point (atomic coordinates are floating-point values)
    parser.add_argument('--cx', type=float, help='Connection point x coordinate.')
    parser.add_argument('--cy', type=float, help='Connection point y coordinate.')
    parser.add_argument('--cz', type=float, help='Connection point z coordinate.')
    parser.add_argument('--cname', type=str, help='Connection point atom name.')

    # Removal point
    parser.add_argument('--rx', type=float, help='Removal point x coordinate.')
    parser.add_argument('--ry', type=float, help='Removal point y coordinate.')
    parser.add_argument('--rz', type=float, help='Removal point z coordinate.')
    parser.add_argument('--rname', type=str, help='Removal point atom name.')

    # Misc
    parser.add_argument('--num_grids', type=int, default=4, help='Number of grid rotations.')
    parser.add_argument('-k', '--top_k', type=int, default=25, help='Number of results to show. Set to -1 to show all.')
    parser.add_argument('--out', type=str, help='Path to output CSV file.')

    args = parser.parse_args()

    # Each entry is (alternative argument sets, whether one of them is required).
    groupings = [
        ([('receptor', 'ligand'), ('pdb', 'resnum')], True),
        ([('cx', 'cy', 'cz'), ('cname',)], True),
        ([('rx', 'ry', 'rz'), ('rname',)], False)
    ]

    for grp, req in groupings:
        partial = []
        complete = 0

        for subset in grp:
            res = [getattr(args, name) is not None for name in subset]
            partial.append(any(res) and not all(res))
            complete += int(all(res))

        if any(partial) or complete > 1 or (complete != 1 and req):
            # Invalid arg combination.
            print('Invalid arguments, must specify exactly one of the following combinations:')
            for subset in grp:
                print('\t%s' % ', '.join(['--' + x for x in subset]))
            raise SystemExit(1)

    ensure_cli_data()
    run(args)

if __name__=='__main__':
    main()
