Commit 83b3c537 authored by hgarrereyn

switch to use atom typers

parent c70e2963
@@ -21,146 +21,119 @@ import tqdm
#
# If the layer index is -1, the atom is ignored.
def rec_typer_single(t):
num, hacc, hdon, aro, _ = t
if num != 1:
return 0
else:
return -1
def rec_typer_single_h(t):
return 0
def rec_typer_basic(t):
BASIC_TYPES = [6,7,8,16]
num, hacc, hdon, aro, _ = t
if num in BASIC_TYPES:
return BASIC_TYPES.index(num)
else:
return -1
def rec_typer_basic_h(t):
BASIC_TYPES = [1,6,7,8,16]
num, hacc, hdon, aro, _ = t
if num in BASIC_TYPES:
return BASIC_TYPES.index(num)
else:
return -1
# (num, aro, hdon, hacc)
REC_DESC = [
(6,0,0,0), # C
(6,1,0,0), # C-aro
(7,0,0,0), # N
(7,0,0,1), # N-hacc
(7,0,1,0), # N-hdon
(7,1,0,0), # N-aro
(7,0,1,1), # N-hdon-hacc
(7,1,0,1), # N-aro-hacc
(7,1,1,0), # N-aro-hdon
(8,0,0,0), # O
(8,0,0,1), # O-hacc
(8,0,1,1), # O-hdon-hacc
(16,0,0,0), #S
]
REC_DESC_H = [
(1,0,0,0), # H
(6,0,0,0), # C
(6,1,0,0), # C-hacc
(7,0,0,0), # N
(7,0,0,1), # N-aro
(7,0,1,0), # N-hdon
(7,1,0,0), # N-hacc
(7,0,1,1), # N-hdon-aro
(7,1,0,1), # N-hacc-aro
(7,1,1,0), # N-hacc-hdon
(8,0,0,0), # O
(8,0,0,1), # O-aro
(8,0,1,1), # O-hdon-aro
(16,0,0,0), #S
]
def rec_typer_desc(t):
num, hacc, hdon, aro, _ = t
f = (num,hacc,hdon,aro)
if f in REC_DESC:
return REC_DESC.index(f)
else:
return -1
def rec_typer_desc_h(t):
num, hacc, hdon, aro, _ = t
f = (num,hacc,hdon,aro)
if f in REC_DESC_H:
return REC_DESC_H.index(f)
else:
return -1
def lig_typer_single(t):
if t != 1:
return 0
else:
return -1
def lig_typer_single_h(t):
return 0
class AtomTyper(object):
def __init__(self, fn, num_layers):
"""Initialize an atom typer.
Args:
fn: a function of type:
(atomic_num, aro, hdon, hacc, pcharge) -> (mask)
num_layers: number of output layers (<=32)
"""
self._fn = fn
self._num_layers = num_layers
def size(self):
return self._num_layers
def apply(self, *args):
return self._fn(*args)
class CondAtomTyper(AtomTyper):
def __init__(self, cond_func):
assert len(cond_func) <= 32
def _fn(*args):
v = 0
for k in range(len(cond_func)):
if cond_func[k](*args):
v |= 1 << k
return v
super(CondAtomTyper, self).__init__(_fn, len(cond_func))
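For reference, a minimal sketch of how a CondAtomTyper turns its predicate list into a per-atom bitmask; the typer and the atom values below are made up for illustration and are not part of this commit.

# Illustrative 3-channel typer built the same way as the entries further below.
example_typer = CondAtomTyper([
    lambda num: num == 6,                  # bit 0: carbon
    lambda num: num == 7,                  # bit 1: nitrogen
    lambda num: num not in [0, 1, 6, 7],   # bit 2: everything else (except H)
])

example_typer.size()    # 3 output layers
example_typer.apply(6)  # 0b001 == 1 (carbon sets bit 0)
example_typer.apply(8)  # 0b100 == 4 (oxygen falls into the catch-all channel)
example_typer.apply(1)  # 0 (hydrogen matches no predicate, so it is skipped)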
def lig_typer_simple(t):
BASIC_TYPES = [6,7,8]
if t in BASIC_TYPES:
return BASIC_TYPES.index(t)
else:
return -1
def lig_typer_simple_h(t):
BASIC_TYPES = [1,6,7,8]
if t in BASIC_TYPES:
return BASIC_TYPES.index(t)
else:
return -1
def lig_typer_desc(t):
DESC_TYPES = [6,7,8,16,9,15,17,35,5,53]
if t in DESC_TYPES:
return DESC_TYPES.index(t)
else:
return -1
def lig_typer_desc_h(t):
DESC_TYPES = [1,6,7,8,16,9,15,17,35,5,53]
if t in DESC_TYPES:
return DESC_TYPES.index(t)
else:
return -1
REC_TYPER = {
'single': rec_typer_single,
'single_h': rec_typer_single_h,
'simple': rec_typer_basic,
'simple_h': rec_typer_basic_h,
'desc': rec_typer_desc,
'desc_h': rec_typer_desc_h
# 1 channel, no hydrogen
'single': CondAtomTyper([
lambda num, aro, hdon, hacc, pcharge: num not in [0,1]
]),
# 1 channel, including hydrogen
'single_h': CondAtomTyper([
lambda num, aro, hdon, hacc, pcharge: num != 0
]),
# (C,N,O,S,*)
'simple': CondAtomTyper([
lambda num, aro, hdon, hacc, pcharge: num == 6,
lambda num, aro, hdon, hacc, pcharge: num == 7,
lambda num, aro, hdon, hacc, pcharge: num == 8,
lambda num, aro, hdon, hacc, pcharge: num == 16,
lambda num, aro, hdon, hacc, pcharge: num not in [0,1,6,7,8,16],
]),
# (H,C,N,O,S,*)
'simple_h': CondAtomTyper([
lambda num, aro, hdon, hacc, pcharge: num == 1,
lambda num, aro, hdon, hacc, pcharge: num == 6,
lambda num, aro, hdon, hacc, pcharge: num == 7,
lambda num, aro, hdon, hacc, pcharge: num == 8,
lambda num, aro, hdon, hacc, pcharge: num == 16,
lambda num, aro, hdon, hacc, pcharge: num not in [0,1,6,7,8,16],
]),
# (aro, hdon, hacc, positive, negative, occ)
'meta': CondAtomTyper([
lambda num, aro, hdon, hacc, pcharge: bool(aro), # aromatic
lambda num, aro, hdon, hacc, pcharge: bool(hdon), # hydrogen donor
lambda num, aro, hdon, hacc, pcharge: bool(hacc), # hydrogen acceptor
lambda num, aro, hdon, hacc, pcharge: pcharge >= 128, # partial positive
lambda num, aro, hdon, hacc, pcharge: pcharge < 128, # partial negative
lambda num, aro, hdon, hacc, pcharge: num != 0, # occupancy
]),
# (aro, hdon, hacc, positive, negative, occ, H, C, N, O, S)
'meta_mix': CondAtomTyper([
lambda num, aro, hdon, hacc, pcharge: bool(aro), # aromatic
lambda num, aro, hdon, hacc, pcharge: bool(hdon), # hydrogen donor
lambda num, aro, hdon, hacc, pcharge: bool(hacc), # hydrogen acceptor
lambda num, aro, hdon, hacc, pcharge: pcharge >= 128, # partial positive
lambda num, aro, hdon, hacc, pcharge: pcharge < 128, # partial negative
lambda num, aro, hdon, hacc, pcharge: num != 0, # occupancy
lambda num, aro, hdon, hacc, pcharge: num == 1, # hydrogen
lambda num, aro, hdon, hacc, pcharge: num == 6, # carbon
lambda num, aro, hdon, hacc, pcharge: num == 7, # nitrogen
lambda num, aro, hdon, hacc, pcharge: num == 8, # oxygen
lambda num, aro, hdon, hacc, pcharge: num == 16, # sulfur
])
}
LIG_TYPER = {
'single': lig_typer_single,
'single_h': lig_typer_single_h,
'simple': lig_typer_simple,
'simple_h': lig_typer_simple_h,
'desc': lig_typer_desc,
'desc_h': lig_typer_desc_h
# 1 channel, no hydrogen
'single': CondAtomTyper([
lambda num: num not in [0,1]
]),
# 1 channel, including hydrogen
'single_h': CondAtomTyper([
lambda num: num != 0
]),
'simple': CondAtomTyper([
lambda num: num == 6, # carbon
lambda num: num == 7, # nitrogen
lambda num: num == 8, # oxygen
lambda num: num not in [0,1,6,7,8] # extra
]),
'simple_h': CondAtomTyper([
lambda num: num == 1, # hydrogen
lambda num: num == 6, # carbon
lambda num: num == 7, # nitrogen
lambda num: num == 8, # oxygen
lambda num: num not in [0,1,6,7,8] # extra
])
}
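Roughly how these dictionaries are meant to be consumed downstream; the argument values are illustrative only (the receptor lambdas take the full (num, aro, hdon, hacc, pcharge) tuple positionally):

rec_typer = REC_TYPER['simple']     # 5 channels: C, N, O, S, other
lig_typer = LIG_TYPER['simple_h']   # 5 channels: H, C, N, O, other

rec_typer.size()                    # 5
rec_typer.apply(8, 0, 0, 1, 130)    # oxygen -> bit 2 -> 0b00100 == 4
lig_typer.apply(6)                  # carbon -> bit 1 -> 0b00010 == 2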
@@ -169,13 +142,13 @@ class FragmentDataset(Dataset):
def __init__(self, fragment_file, rec_typer, lig_typer, filter_rec=None,
fdist_min=None, fdist_max=None, fmass_min=None,
fmass_max=None, verbose=False, skip_remap=False):
fmass_max=None, verbose=False, lazy_loading=True):
"""Initializes the fragment dataset.
Args:
fragment_file: path to fragments.h5
rec_typer: function to map receptor rows to layer index
lig_typer: function to map ligand rows to layer index
rec_typer: AtomTyper for receptor
lig_typer: AtomTyper for ligand
filter_rec: list of receptor ids to use (or None to use all)
skip_remap: if True, don't prepare atom type information
@@ -185,8 +158,11 @@ class FragmentDataset(Dataset):
fmass_min: minimum fragment mass (Da)
fmass_max: maximum fragment mass (Da)
"""
self._rec_typer = rec_typer
self._lig_typer = lig_typer
self.verbose = verbose
self._skip_remap = skip_remap
self._lazy_loading = lazy_loading
self.rec = self._load_rec(fragment_file, rec_typer)
self.frag = self._load_fragments(fragment_file, lig_typer)
@@ -206,10 +182,12 @@ class FragmentDataset(Dataset):
if self.verbose:
r = tqdm.tqdm(r, desc='Remap receptor atoms')
rec_remapped = np.zeros(len(rec_types)).astype(np.int32)
if not self._skip_remap:
rec_remapped = np.zeros(len(rec_types), dtype=np.uint32)
if not self._lazy_loading:
for i in r:
rec_remapped[i] = rec_typer(rec_types[i])
rec_remapped[i] = rec_typer.apply(*rec_types[i])
rec_loaded = np.zeros(len(rec_lookup)).astype(np.int32)
# create rec mapping
rec_mapping = {}
@@ -221,7 +199,8 @@ class FragmentDataset(Dataset):
'rec_types': rec_types,
'rec_remapped': rec_remapped,
'rec_lookup': rec_lookup,
'rec_mapping': rec_mapping
'rec_mapping': rec_mapping,
'rec_loaded': rec_loaded
}
f.close()
@@ -242,12 +221,13 @@ class FragmentDataset(Dataset):
frag_coords = frag_data[:,:3].astype(np.float32)
frag_types = frag_data[:,3].astype(np.int32)
frag_remapped = None
if self._skip_remap:
frag_remapped = np.zeros(len(frag_types))
else:
frag_remapped = np.vectorize(lig_typer)(frag_types)
frag_remapped = np.zeros(len(frag_types), dtype=np.uint32)
if not self._lazy_loading:
for i in range(len(frag_types)):
frag_remapped[i] = lig_typer.apply(frag_types[i])
frag_loaded = np.zeros(len(frag_lookup)).astype(np.int32)
# find and save connection point
r = range(len(frag_lookup))
if self.verbose:
@@ -276,6 +256,7 @@ class FragmentDataset(Dataset):
'frag_smiles': frag_smiles, # f_idx -> smiles
'frag_mass': frag_mass, # f_idx -> mass
'frag_dist': frag_dist, # f_idx -> dist
'frag_loaded': frag_loaded
}
f.close()
@@ -347,20 +328,41 @@ class FragmentDataset(Dataset):
_, r_start, r_end = self.rec['rec_lookup'][rec_idx]
# fetch data
f_coords = self.frag['frag_coords'][f_start:f_end]
f_types = self.frag['frag_remapped'][f_start:f_end]
# f_coords = self.frag['frag_coords'][f_start:f_end]
# f_types = self.frag['frag_types'][f_start:f_end]
p_coords = self.frag['frag_coords'][p_start:p_end]
p_types = self.frag['frag_remapped'][p_start:p_end]
r_coords = self.rec['rec_coords'][r_start:r_end]
r_types = self.rec['rec_remapped'][r_start:r_end]
if self._lazy_loading and self.frag['frag_loaded'][frag_idx] == 0:
frag_types = self.frag['frag_types']
frag_remapped = self.frag['frag_remapped']
# load parent
for i in range(p_start, p_end):
frag_remapped[i] = self._lig_typer.apply(frag_types[i])
self.frag['frag_loaded'][frag_idx] = 1
if self._lazy_loading and self.rec['rec_loaded'][rec_idx] == 0:
rec_types = self.rec['rec_types']
rec_remapped = self.rec['rec_remapped']
# load receptor
for i in range(r_start, r_end):
rec_remapped[i] = self._rec_typer.apply(*rec_types[i])
self.rec['rec_loaded'][rec_idx] = 1
p_mask = self.frag['frag_remapped'][p_start:p_end]
r_mask = self.rec['rec_remapped'][r_start:r_end]
return {
'f_coords': f_coords,
'f_types': f_types,
# 'f_coords': f_coords,
# 'f_types': f_types,
'p_coords': p_coords,
'p_types': p_types,
'p_types': p_mask,
'r_coords': r_coords,
'r_types': r_types,
'r_types': r_mask,
'conn': conn,
'smiles': smiles
}
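The lazy branch above is a compute-on-first-access cache keyed by fragment and receptor index. A condensed, standalone sketch of the same pattern follows; the array sizes and atomic numbers are made up, and only the typer comes from this file.

import numpy as np

num_examples, num_atoms = 2, 6
loaded = [False] * num_examples              # plays the role of frag_loaded / rec_loaded
remapped = np.zeros(num_atoms, np.uint32)    # plays the role of frag_remapped / rec_remapped
raw_types = [6, 7, 8, 1, 6, 16]              # made-up atomic numbers

def typed_slice(example_idx, start, end, typer):
    # Type this example's atoms only the first time it is requested.
    if not loaded[example_idx]:
        for i in range(start, end):
            remapped[i] = typer.apply(raw_types[i])
        loaded[example_idx] = True
    return remapped[start:end]

typed_slice(0, 0, 3, LIG_TYPER['simple_h'])  # computed on first access
typed_slice(0, 0, 3, LIG_TYPER['simple_h'])  # second call is served from the cache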
@@ -375,6 +377,12 @@ class FragmentDataset(Dataset):
return list(valid_smiles)
def lig_layers(self):
return self._lig_typer.size()
def rec_layers(self):
return self._rec_typer.size()
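One plausible way to use the new accessors: the typer sizes tell a caller how many voxel channels to allocate. The file path and the channel arithmetic below are assumptions for illustration, not part of the commit.

dataset = FragmentDataset(
    'fragments.h5',                  # hypothetical path
    rec_typer=REC_TYPER['simple'],
    lig_typer=LIG_TYPER['simple_h'],
)
num_channels = dataset.rec_layers() + dataset.lig_layers()   # 5 + 5 = 10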
class FingerprintDataset(Dataset):
@@ -422,175 +430,3 @@ class FingerprintDataset(Dataset):
fp[i] = self.fingerprints['fingerprint_data'][fp_idx]
return torch.Tensor(fp)
# class FragmentDataset2(Dataset):
# '''
# Utility class to work with the packed fragments.h5 format
# (no fingerprints)
# '''
# def __init__(self, fragment_file, rec_typer, lig_typer, filter_rec=None,
# fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None, verbose=False):
# '''
# Initialize the fragment dataset
# Params:
# - fragment_file: path to fragments.h5
# - rec_typer: function to map receptor rows to layer index
# - lig_typer: function to map ligand rows to layer index
# - filter_rec: list of receptor ids to use (or None to use all)
# Filtering options:
# - fdist_min: minimum fragment distance
# - fdist_max: maximum fragment distance
# - fmass_min: minimum fragment mass (Da)
# - fmass_max: maximum fragment mass (Da)
# '''
# self.verbose = verbose
# # load receptor/fragment information
# self.rec = self.load_rec(fragment_file, rec_typer)
# self.frag = self.load_fragments(fragment_file, lig_typer)
# # keep track of valid examples
# valid_mask = np.ones(self.frag['frag_lookup'].shape[0]).astype(np.bool)
# # filter by receptor id
# if filter_rec is not None:
# valid_rec = np.vectorize(lambda k: k[0].decode('ascii') in filter_rec)(self.frag['frag_lookup'])
# valid_mask *= valid_rec
# # filter by fragment distance
# if fdist_min is not None:
# valid_mask[self.frag['frag_dist'] < fdist_min] = 0
# if fdist_max is not None:
# valid_mask[self.frag['frag_dist'] > fdist_max] = 0
# # filter by fragment mass
# if fmass_min is not None:
# valid_mask[self.frag['frag_mass'] < fmass_min] = 0
# if fmass_max is not None:
# valid_mask[self.frag['frag_mass'] > fmass_max] = 0
# # convert to a list of indexes
# self.valid_idx = np.where(valid_mask)[0]
# def load_rec(self, fragment_file, rec_typer):
# '''Load receptor information'''
# f = h5py.File(fragment_file, 'r')
# rec_coords = f['rec_coords'][()]
# rec_types = f['rec_types'][()]
# rec_lookup = f['rec_lookup'][()]
# r = range(len(rec_types))
# if self.verbose:
# r = tqdm.tqdm(r, desc='Remap receptor atoms')
# rec_remapped = np.zeros(len(rec_types)).astype(np.int32)
# for i in r:
# rec_remapped[i] = rec_typer(rec_types[i])
# # create rec mapping
# rec_mapping = {}
# for i in range(len(rec_lookup)):
# rec_mapping[rec_lookup[i][0].decode('ascii')] = i
# rec = {
# 'rec_coords': rec_coords,
# 'rec_types': rec_types,
# 'rec_remapped': rec_remapped,
# 'rec_lookup': rec_lookup,
# 'rec_mapping': rec_mapping
# }
# f.close()
# return rec
# def load_fragments(self, fragment_file, lig_typer):
# '''Load fragment information'''
# f = h5py.File(fragment_file, 'r')
# frag_data = f['frag_data'][()]
# frag_lookup = f['frag_lookup'][()]
# frag_smiles = f['frag_smiles'][()]
# frag_mass = f['frag_mass'][()]
# frag_dist = f['frag_dist'][()]
# # unpack frag data into separate structures
# frag_coords = frag_data[:,:3].astype(np.float32)
# frag_types = frag_data[:,3].astype(np.int32)
# frag_remapped = np.vectorize(lig_typer)(frag_types)
# # find and save connection point
# r = range(len(frag_lookup))
# if self.verbose:
# r = tqdm.tqdm(r, desc='Frag connection point')
# frag_conn = np.zeros((len(frag_lookup), 3))
# for i in r:
# _,f_start,f_end,_,_ = frag_lookup[i]
# fdat = frag_data[f_start:f_end]
# found = False
# for j in range(len(fdat)):
# if fdat[j][3] == 0:
# frag_conn[i,:] = tuple(fdat[j])[:3]
# found = True
# break
# assert found, "missing fragment connection point at %d" % i
# frag = {
# 'frag_coords': frag_coords, # d_idx -> (x,y,z)
# 'frag_types': frag_types, # d_idx -> (type)
# 'frag_remapped': frag_remapped, # d_idx -> (layer)
# 'frag_lookup': frag_lookup, # f_idx -> (rec_id, fstart, fend, pstart, pend)
# 'frag_conn': frag_conn, # f_idx -> (x,y,z)
# 'frag_smiles': frag_smiles, # f_idx -> smiles
# 'frag_mass': frag_mass, # f_idx -> mass
# 'frag_dist': frag_dist, # f_idx -> dist
# }
# f.close()
# return frag
# def __len__(self):
# '''returns the number of fragment examples'''
# return self.valid_idx.shape[0]
# def __getitem__(self, idx):
# '''
# retrieve the nth example
# returns (f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fingerprint, extra)
# '''
# # convert to fragment index
# frag_idx = self.valid_idx[idx]
# # lookup fragment
# rec_id, f_start, f_end, p_start, p_end = self.frag['frag_lookup'][frag_idx]
# smiles = self.frag['frag_smiles'][frag_idx]
# conn = self.frag['frag_conn'][frag_idx]
# # lookup receptor
# rec_idx = self.rec['rec_mapping'][rec_id.decode('ascii')]
# _, r_start, r_end = self.rec['rec_lookup'][rec_idx]
# # fetch data
# f_coords = self.frag['frag_coords'][f_start:f_end]
# f_types = self.frag['frag_remapped'][f_start:f_end]
# p_coords = self.frag['frag_coords'][p_start:p_end]
# p_types = self.frag['frag_remapped'][p_start:p_end]
# r_coords = self.rec['rec_coords'][r_start:r_end]
# r_types = self.rec['rec_remapped'][r_start:r_end]
# return f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, smiles
\ No newline at end of file
@@ -14,9 +14,24 @@ import numpy as np
GPU_DIM = 8
class POINT_TYPE(object):
EXP = 0 # simple exponential sphere fill
SPHERE = 1 # fixed sphere fill
CUBE = 2 # fixed cube fill
GAUSSIAN = 3 # continuous piecewise exponential fill
LJ = 4 # Lennard-Jones-style fill
DISCRETE = 5 # discrete (nearest-voxel) fill
class ACC_TYPE(object):
SUM = 0
MAX = 1
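These constants are plain integers so they can be passed directly into the gpu_gridify kernel defined just below. A hypothetical launch for illustration; the launch configuration and every array here are placeholders that this hunk does not define.

gpu_gridify[blocks, threads](
    grid, atom_num, atom_coords, atom_mask,
    0,                 # layer_offset
    0,                 # batch_idx
    width, res, center, rot,
    1.0,               # point_radius (placeholder value)
    POINT_TYPE.EXP,    # fill function
    ACC_TYPE.SUM,      # accumulation mode
)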
@numba.cuda.jit
def gpu_gridify(grid, atom_num, atom_coords, atom_layers, layer_offset,
batch_idx, width, res, center, rot):
def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
batch_idx, width, res, center, rot,
point_radius, point_type, acc_type
):
"""Adds atoms to the grid in a GPU kernel.
This kernel converts atom coordinate information to 3d voxel information.
@@ -44,8 +59,8 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_layers, layer_offset,
grid: DeviceNDArray tensor where grid information is stored
atom_num: number of atoms
atom_coords: array containing (x,y,z) atom coordinates
atom_layers: array containing (idx) offsets that specify which layer to
store this atom. (-1 can be used to ignore an atom)
atom_mask: uint32 array of size atom_num containing a destination
layer bitmask (i.e. if bit k is set, the atom is written to layer k)
layer_offset: a fixed offset added to each atom layer index
batch_idx: index specifying which batch to write information to
width: number of grid points in each dimension
@@ -99,41 +114,96 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_layers, layer_offset,
while i < atom_num:
# fetch atom
fx, fy, fz = atom_coords[i]
ft = atom_layers[i]
mask = atom_mask[i]
i += 1
# invisible atoms
if ft == -1:
if mask == 0:
continue
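The remainder of the kernel is not shown in this hunk, but the mask is a standard destination bitmask. A plain-Python helper, purely illustrative, showing how such a mask expands into layer indices:

def layers_from_mask(mask, layer_offset=0):
    # Expand a uint32 destination bitmask into a list of layer indices.
    layers = []
    k = 0
    while mask:
        if mask & 1:
            layers.append(layer_offset + k)
        mask >>= 1
        k += 1
    return layers

layers_from_mask(0b10010)  # -> [1, 4]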