grid_util.py 8.9 KB
Newer Older
hgarrereyn's avatar
hgarrereyn committed
1
2
3
"""
Contains code for gpu-accelerated grid generation.
"""
Harrison Green's avatar
init  
Harrison Green committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17

import math
import ctypes

import torch
import numba
import numba.cuda
import numpy as np


GPU_DIM = 8


@numba.cuda.jit
def gpu_gridify(grid, atom_num, atom_coords, atom_layers, layer_offset,
                batch_idx, width, res, center, rot):
    """Adds atoms to the grid in a GPU kernel.

    This kernel converts atom coordinate information to 3d voxel information.
    Each GPU thread is responsible for one specific grid point. This function
    receives a list of atomic coordinates and atom layers and simply iterates
    over the list to find nearby atoms and add their effect.

    Voxel information is stored in a 5D tensor of type: BxTxNxNxN where:
        B = batch size
        T = number of atom types (receptor + ligand)
        N = grid width (in gridpoints)

    Each invocation of this function will write information to a specific batch
    index specified by batch_idx. Additionally, the layer_offset parameter can
    be set to specify a fixed offset to add to each atom_layer item.

    How it works:
    1. Each GPU thread controls a single gridpoint. This gridpoint coordinate
        is translated to a "real world" coordinate by applying rotation and
        translation vectors.
    2. Each thread iterates over the list of atoms and checks for atoms within
        a threshold to add to the grid.

    Threads whose (x,y,z) index falls outside the width^3 grid (which happens
    when the launch grid is rounded up to a whole number of blocks) return
    immediately without touching memory.

    Args:
        grid: DeviceNDArray tensor where grid information is stored
        atom_num: number of atoms
        atom_coords: array containing (x,y,z) atom coordinates
        atom_layers: array containing (idx) offsets that specify which layer to
            store this atom. (-1 can be used to ignore an atom)
        layer_offset: a fixed offset added to each atom layer index
        batch_idx: index specifying which batch to write information to
        width: number of grid points in each dimension
        res: distance between neighboring grid points in angstroms
            (1 == gridpoint every angstrom)
            (0.5 == gridpoint every half angstrom, e.g. tighter grid)
        center: (x,y,z) coordinate of grid center
        rot: (w,x,y,z) rotation quaternion
    """
    x, y, z = numba.cuda.grid(3)

    # The launch configuration may contain more threads than grid points;
    # surplus threads must not read or write the grid.
    if x >= width or y >= width or z >= width:
        return

    # center around origin
    tx = x - (width/2)
    ty = y - (width/2)
    tz = z - (width/2)

    # scale by resolution
    tx = tx * res
    ty = ty * res
    tz = tz * res

    # rotation quaternion a = (w, x, y, z)
    aw = rot[0]
    ax = rot[1]
    ay = rot[2]
    az = rot[3]

    # grid point as a pure quaternion b = (0, tx, ty, tz)
    bw = 0
    bx = tx
    by = ty
    bz = tz

    # c = a * b
    cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz)
    cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by)
    cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz)
    cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx)

    # d = c * conj(a); the scalar part of d is never used, so it is not
    # computed
    dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay))
    dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az))
    dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax))

    # apply translation vector
    tx = dx + center[0]
    ty = dy + center[1]
    tz = dz + center[2]

    # fixed radius (^2), identical for every atom so set once
    r2 = 4

    for i in range(atom_num):
        # fetch atom
        fx, fy, fz = atom_coords[i]
        ft = atom_layers[i]

        # invisible atoms
        if ft == -1:
            continue

        # quick cube bounds check (conservative pre-filter; the exact
        # squared-distance test below decides whether the atom contributes)
        if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2:
            continue

        # compute squared distance to atom
        d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2

        # add effect; the exponential is only evaluated for atoms in range
        if d2 < r2:
            grid[batch_idx, layer_offset+ft, x, y, z] += math.exp((-2 * d2) / r2)


def mol_gridify(grid, atom_coords, atom_layers, layer_offset, batch_idx,
                width, res, center, rot):
    """Launches the gpu_gridify kernel for a single set of atoms.

    The kernel is launched on a cubic arrangement of GPU_DIM^3-thread blocks
    with enough blocks along each axis to cover `width` grid points.

    (See gpu_gridify() for a description of the arguments)
    """
    # ceil(width / GPU_DIM) using integer arithmetic
    blocks_per_axis = (width + GPU_DIM - 1) // GPU_DIM

    launch_grid = (blocks_per_axis, blocks_per_axis, blocks_per_axis)
    launch_block = (GPU_DIM, GPU_DIM, GPU_DIM)

    kernel = gpu_gridify[launch_grid, launch_block]
    kernel(
        grid, len(atom_coords), atom_coords, atom_layers, layer_offset,
        batch_idx, width, res, center, rot
    )
Harrison Green's avatar
init  
Harrison Green committed
138
139
140


def make_tensor(shape):
    """Creates a pytorch tensor and numba array with shared GPU memory backing.

    The tensor is a zero-filled float32 tensor allocated on the current CUDA
    device. The numba array is built on top of the tensor's data pointer with
    the same shape and byte strides, so both objects refer to the same memory.

    Args:
        shape: the shape of the array

    Returns:
        (torch_arr, cuda_arr)
    """
    # get cuda context
    ctx = numba.cuda.cudadrv.driver.driver.get_active_context()

    # setup tensor on gpu (allocated on the device directly rather than being
    # created on the host and copied over)
    t = torch.zeros(size=shape, dtype=torch.float32, device='cuda')

    # bytes per element; used for both the buffer size and the byte strides
    itemsize = t.element_size()

    # c_size_t is pointer-sized on every platform, whereas c_ulong is only
    # 32 bits wide on LLP64 systems and would truncate the device address
    memory = numba.cuda.cudadrv.driver.MemoryPointer(
        ctx, ctypes.c_size_t(t.data_ptr()), t.numel() * itemsize)
    cuda_arr = numba.cuda.cudadrv.devicearray.DeviceNDArray(
        t.size(),
        [s * itemsize for s in t.stride()],  # torch strides are in elements
        np.dtype('float32'),
        gpu_data=memory,
        stream=torch.cuda.current_stream().cuda_stream
    )

    return (t, cuda_arr)
jdurrant's avatar
jdurrant committed
165

Harrison Green's avatar
init  
Harrison Green committed
166
167

def rand_rot():
    """Returns a random rotation as a unit-length quaternion.

    Four components are drawn from a standard normal distribution and the
    resulting vector is scaled to unit length.
    """
    components = np.random.normal(size=4)
    length = np.sqrt((components ** 2).sum())
    return components / length
jdurrant's avatar
jdurrant committed
172

Harrison Green's avatar
init  
Harrison Green committed
173

hgarrereyn's avatar
hgarrereyn committed
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def get_batch(data, rec_channels, parent_channels, batch_size=16, batch_set=None,
              width=48, res=0.5, ignore_receptor=False, ignore_parent=False):
    """Builds a batch grid from a FragmentDataset.

    Every example in the batch is gridded with its own random rotation,
    centered on the example's 'conn' point. Parent atoms occupy the first
    layers of the grid; receptor atoms follow (or start at layer 0 when the
    parent is ignored).

    Args:
        data: a FragmentDataset object
        rec_channels: number of receptor channels
        parent_channels: number of parent channels
        batch_size: size of the batch
        batch_set: if not None, specify a list of data indexes to use for each
            item in the batch (must not contain more than batch_size items)
        width: grid width
        res: grid resolution
        ignore_receptor: if True, ignore receptor atoms
        ignore_parent: if True, ignore parent atoms

    Returns: (torch_grid, examples)
        torch_grid: pytorch Tensor with voxel information
        examples: list of examples used

    Raises:
        ValueError: if batch_set contains more entries than batch_size
    """
    assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"

    # The grid kernel performs no bounds checking on the batch index, so an
    # oversized batch_set would write past the end of the allocated grid.
    if batch_set is not None and len(batch_set) > batch_size:
        raise ValueError(
            'batch_set has %d entries but batch_size is %d'
            % (len(batch_set), batch_size))

    # (coords key, types key, layer offset) for each atom set that is drawn,
    # in drawing order: parent first, then receptor
    sources = []
    dim = 0
    if not ignore_parent:
        sources.append(('p_coords', 'p_types', dim))
        dim += parent_channels
    if not ignore_receptor:
        sources.append(('r_coords', 'r_types', dim))
        dim += rec_channels

    # create a tensor with shared memory on the gpu
    torch_grid, cuda_grid = make_tensor((batch_size, dim, width, width, width))

    if batch_set is None:
        batch_set = np.random.choice(len(data), size=batch_size, replace=False)

    examples = [data[idx] for idx in batch_set]

    for i, example in enumerate(examples):
        # one rotation per example, shared by its parent and receptor atoms
        rot = rand_rot()

        for coords_key, types_key, offset in sources:
            mol_gridify(
                cuda_grid,
                example[coords_key],
                example[types_key],
                layer_offset=offset,
                batch_idx=i,
                width=width,
                res=res,
                center=example['conn'],
                rot=rot
            )

    return torch_grid, examples


def get_raw_batch(r_coords, r_types, p_coords, p_types, conn, num_samples=32,
                  width=24, res=1, rec_channels=9, parent_channels=9):
    """Sample a raw batch with provided atom coordinates.

    The same receptor/parent pair is gridded num_samples times, each time with
    a fresh random rotation about the connection point. Parent atoms occupy
    layers [0, parent_channels) and receptor atoms are offset by
    parent_channels.

    Args:
        r_coords: receptor coordinates
        r_types: receptor types (layers)
        p_coords: parent coordinates
        p_types: parent types (layers)
        conn: (x,y,z) connection point
        num_samples: number of rotations to sample
        width: grid width
        res: grid resolution
        rec_channels: number of receptor channels
        parent_channels: number of parent channels

    Returns:
        torch_grid: pytorch Tensor with voxel information, of shape
            (num_samples, rec_channels + parent_channels, width, width, width)
    """
    B = num_samples
    T = rec_channels + parent_channels
    N = width

    torch_grid, cuda_grid = make_tensor((B,T,N,N,N))

    for i in range(num_samples):
        rot = rand_rot()
        mol_gridify(cuda_grid, p_coords, p_types, layer_offset=0, batch_idx=i,
                    width=width, res=res, center=conn, rot=rot)
        # receptor layers start right after the parent layers
        mol_gridify(cuda_grid, r_coords, r_types, layer_offset=parent_channels,
                    batch_idx=i, width=width, res=res, center=conn, rot=rot)

    return torch_grid