Commit 44579ec9 authored by jdurrant's avatar jdurrant
Browse files

Merge branch 'cpu_gridify' into 'main'

add cpu_gridify option, fix config import, remove openbabel dependency

See merge request !1
parents 4d81a065 a54e9656
......@@ -209,17 +209,194 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
numba.cuda.atomic.max(grid, idx, val)
@numba.jit(nopython=True)
def cpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
batch_idx, width, res, center, rot,
point_radius, point_type, acc_type
):
"""Adds atoms to the grid in a GPU kernel.
This kernel converts atom coordinate information to 3d voxel information.
Each GPU thread is responsible for one specific grid point. This function
receives a list of atomic coordinates and atom layers and simply iterates
over the list to find nearby atoms and add their effect.
Voxel information is stored in a 5D tensor of type: BxTxNxNxN where:
B = batch size
T = number of atom types (receptor + ligand)
N = grid width (in gridpoints)
Each invocation of this function will write information to a specific batch
index specified by batch_idx. Additionally, the layer_offset parameter can
be set to specify a fixed offset to add to each atom_layer item.
How it works:
1. Each GPU thread controls a single gridpoint. This gridpoint coordinate
is translated to a "real world" coordinate by applying rotation and
translation vectors.
2. Each thread iterates over the list of atoms and checks for atoms within
a threshold to add to the grid.
Args:
grid: DeviceNDArray tensor where grid information is stored
atom_num: number of atoms
atom_coords: array containing (x,y,z) atom coordinates
atom_mask: uint32 array of size atom_num containing a destination
layer bitmask (i.e. if bit k is set, write atom to index k)
layer_offset: a fixed ofset added to each atom layer index
batch_idx: index specifiying which batch to write information to
width: number of grid points in each dimension
res: distance between neighboring grid points in angstroms
(1 == gridpoint every angstrom)
(0.5 == gridpoint every half angstrom, e.g. tighter grid)
center: (x,y,z) coordinate of grid center
rot: (x,y,z,y) rotation quaternion
"""
# x,y,z = numba.cuda.grid(3)
for x in range(width):
for y in range(width):
for z in range(width):
# center around origin
tx = x - (width/2)
ty = y - (width/2)
tz = z - (width/2)
# scale by resolution
tx = tx * res
ty = ty * res
tz = tz * res
# apply rotation vector
aw = rot[0]
ax = rot[1]
ay = rot[2]
az = rot[3]
bw = 0
bx = tx
by = ty
bz = tz
# multiply by rotation vector
cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz)
cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by)
cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz)
cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx)
# multiply by conjugate
# dw = (cw * aw) - (cx * (-ax)) - (cy * (-ay)) - (cz * (-az))
dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay))
dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az))
dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax))
# apply translation vector
tx = dx + center[0]
ty = dy + center[1]
tz = dz + center[2]
i = 0
while i < atom_num:
# fetch atom
fx, fy, fz = atom_coords[i]
mask = atom_mask[i]
i += 1
# invisible atoms
if mask == 0:
continue
# point radius squared
r = point_radius
r2 = point_radius * point_radius
# quick cube bounds check
if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2:
continue
# value to add to this gridpoint
val = 0
if point_type == 0: # POINT_TYPE.EXP
# exponential sphere fill
# compute squared distance to atom
d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
if d2 > r2:
continue
# compute effect
val = math.exp((-2 * d2) / r2)
elif point_type == 1: # POINT_TYPE.SPHERE
# solid sphere fill
# compute squared distance to atom
d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
if d2 > r2:
continue
val = 1
elif point_type == 2: # POINT_TYPE.CUBE
# solid cube fill
val = 1
elif point_type == 3: # POINT_TYPE.GAUSSIAN
# (Ragoza, 2016)
#
# piecewise gaussian sphere fill
# compute squared distance to atom
d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
d = math.sqrt(d2)
if d > r * 1.5:
continue
elif d > r:
val = math.exp(-2.0) * ( (4*d2/r2) - (12*d/r) + 9 )
else:
val = math.exp((-2 * d2) / r2)
elif point_type == 4: # POINT_TYPE.LJ
# (Jimenez, 2017) - DeepSite
#
# LJ potential
# compute squared distance to atom
d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
d = math.sqrt(d2)
if d > r * 1.5:
continue
else:
val = 1 - math.exp(-((r/d)**12))
elif point_type == 5: # POINT_TYPE.DISCRETE
# nearest-gridpoint
# L1 distance
if abs(fx-tx) < (res/2) and abs(fy-ty) < (res/2) and abs(fz-tz) < (res/2):
val = 1
# add value to layers
for k in range(32):
if (mask >> k) & 1:
idx = (batch_idx, layer_offset+k, x, y, z)
if acc_type == 0: # ACC_TYPE.SUM
grid[idx] += val
elif acc_type == 1: # ACC_TYPE.MAX
grid[idx] = max(grid[idx], val)
def mol_gridify(grid, atom_coords, atom_mask, layer_offset, batch_idx,
width, res, center, rot, point_radius, point_type, acc_type):
width, res, center, rot, point_radius, point_type, acc_type,
cpu=False):
"""Wrapper around gpu_gridify.
(See gpu_gridify() for details)
"""
dw = ((width - 1) // GPU_DIM) + 1
gpu_gridify[(dw,dw,dw), (GPU_DIM,GPU_DIM,GPU_DIM)](
grid, len(atom_coords), atom_coords, atom_mask, layer_offset,
batch_idx, width, res, center, rot, point_radius, point_type, acc_type
)
if cpu:
cpu_gridify(
grid, len(atom_coords), atom_coords, atom_mask, layer_offset,
batch_idx, width, res, center, rot, point_radius, point_type, acc_type
)
else:
dw = ((width - 1) // GPU_DIM) + 1
gpu_gridify[(dw,dw,dw), (GPU_DIM,GPU_DIM,GPU_DIM)](
grid, len(atom_coords), atom_coords, atom_mask, layer_offset,
batch_idx, width, res, center, rot, point_radius, point_type, acc_type
)
def make_tensor(shape):
......@@ -371,7 +548,7 @@ def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5,
def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
conn, num_samples=32, width=24, res=1, fixed_rot=None,
point_radius=1.5, point_type=0, acc_type=0):
point_radius=1.5, point_type=0, acc_type=0, cpu=False):
"""Sample a raw batch with provided atom coordinates.
Args:
......@@ -383,14 +560,22 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
num_samples: number of rotations to sample
width: grid width
res: grid resolution
rec_channels: number of receptor channels
parent_channels: number of parent chanels
fixed_rot: None or a fixed 4-element rotation vector
point_radius: atom radius in Angstroms
point_type: shape of the atom densities
acc_type: atom density accumulation type
cpu: if True, generate batches with cpu_gridify
"""
B = num_samples
T = rec_typer.size() + lig_typer.size()
N = width
torch_grid, cuda_grid = make_tensor((B,T,N,N,N))
if cpu:
t = np.zeros((B,T,N,N,N))
torch_grid = t
cuda_grid = t
else:
torch_grid, cuda_grid = make_tensor((B,T,N,N,N))
r_mask = np.zeros(len(r_types), dtype=np.uint32)
p_mask = np.zeros(len(p_types), dtype=np.uint32)
......@@ -418,7 +603,8 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
rot=rot,
point_radius=point_radius,
point_type=point_type,
acc_type=acc_type
acc_type=acc_type,
cpu=cpu
)
mol_gridify(
......@@ -433,7 +619,8 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
rot=rot,
point_radius=point_radius,
point_type=point_type,
acc_type=acc_type
acc_type=acc_type,
cpu=cpu
)
return torch_grid
......@@ -19,10 +19,14 @@ import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import tqdm
import numpy as np
try:
import wandb
except:
pass
from leadopt.models.voxel import VoxelFingerprintNet
from leadopt.data_util import FragmentDataset, SharedFragmentDataset, FingerprintDataset, LIG_TYPER,\
REC_TYPER
......@@ -30,7 +34,7 @@ from leadopt.grid_util import get_batch
from leadopt.metrics import mse, bce, tanimoto, cos, top_k_acc,\
average_support, inside_support
from config import partitions, moad_partitions
from config import moad_partitions
def get_bios(p):
......@@ -309,36 +313,6 @@ class VoxelNet(LeadoptModel):
def load_data(self):
print('[*] Loading data...', flush=True)
# train_dat = FragmentDataset(
# self._args['fragments'],
# rec_typer=REC_TYPER[self._args['rec_typer']],
# lig_typer=LIG_TYPER[self._args['lig_typer']],
# # filter_rec=(
# # partitions.TRAIN if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.TRAIN)),
# filter_smi=set(moad_partitions.TRAIN_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )
# val_dat = FragmentDataset(
# self._args['fragments'],
# rec_typer=REC_TYPER[self._args['rec_typer']],
# lig_typer=LIG_TYPER[self._args['lig_typer']],
# # filter_rec=(
# # partitions.VAL if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.VAL)),
# filter_smi=set(moad_partitions.VAL_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )
dat = FragmentDataset(
self._args['fragments'],
rec_typer=REC_TYPER[self._args['rec_typer']],
......
......@@ -19,10 +19,10 @@ rdkit/openbabel utility scripts
import numpy as np
try:
import pybel
except:
from openbabel import pybel
# try:
# import pybel
# except:
# from openbabel import pybel
from rdkit import Chem
......@@ -164,32 +164,45 @@ def load_receptor(rec_path):
return rec
# def load_receptor_ob(rec_path):
# rec = next(pybel.readfile('pdb', rec_path))
# valid = [r for r in rec.residues if r.name != 'HOH']
# # map partial charge into byte range
# def conv_charge(x):
# x = max(x,-0.5)
# x = min(x,0.5)
# x += 0.5
# x *= 255
# x = int(x)
# return x
# coords = []
# types = []
# for v in valid:
# coords += [k.coords for k in v.atoms]
# types += [(
# k.atomicnum,
# int(k.OBAtom.IsAromatic()),
# int(k.OBAtom.IsHbondDonor()),
# int(k.OBAtom.IsHbondAcceptor()),
# conv_charge(k.OBAtom.GetPartialCharge())
# ) for k in v.atoms]
# return np.array(coords), np.array(types)
def load_receptor_ob(rec_path):
rec = next(pybel.readfile('pdb', rec_path))
valid = [r for r in rec.residues if r.name != 'HOH']
# map partial charge into byte range
def conv_charge(x):
x = max(x,-0.5)
x = min(x,0.5)
x += 0.5
x *= 255
x = int(x)
return x
coords = []
types = []
for v in valid:
coords += [k.coords for k in v.atoms]
types += [(
k.atomicnum,
int(k.OBAtom.IsAromatic()),
int(k.OBAtom.IsHbondDonor()),
int(k.OBAtom.IsHbondAcceptor()),
conv_charge(k.OBAtom.GetPartialCharge())
) for k in v.atoms]
return np.array(coords), np.array(types)
rec = load_receptor(rec_path)
coords = get_coords(rec)
types = np.array(get_types(rec))
types = np.concatenate([
types.reshape(-1,1),
np.zeros((len(types), 4))
], 1)
return coords, types
def get_connection_point(frag):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment