Commit 9399243e authored by hgarrereyn's avatar hgarrereyn
Browse files

js gridder fixes

parent d1cbd5c8
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -207,6 +207,160 @@ def mol_gridify(
return grid
def grid_to_real(x,y,z,width,res,center,rot):
"""Convert a grid (x,y,z) coordinate to real world coordinate."""
half_width = width / 2
# Center around origin.
tx = x - half_width
ty = y - half_width
tz = z - half_width
# Scale by resolution.
tx = tx * res
ty = ty * res
tz = tz * res
# Apply rotation vector.
aw = rot[0]
ax = rot[1]
ay = rot[2]
az = rot[3]
# bw = 0
bx = tx
by = ty
bz = tz
# Multiply by rotation vector.
cw = -(ax * bx) - (ay * by) - (az * bz)
cx = (aw * bx) + (ay * bz) - (az * by)
cy = (aw * by) + (az * bx) - (ax * bz)
cz = (aw * bz) + (ax * by) - (ay * bx)
# Multiply by conjugate.
# dw = (cw * aw) - (cx * (-ax)) - (cy * (-ay)) - (cz * (-az))
dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay))
dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az))
dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax))
# Apply translation vector.
tx = dx + center.x # [0]
ty = dy + center.y # [1]
tz = dz + center.z # [2]
return (tx, ty, tz)
def filter_atoms(atom_coords, atom_layers, width, res, center, rot):
"""Filter atoms based on a grid bounding box.
Returns (filt_coords, filt_layers).
"""
PAD = 1.75
# Compute grid extremes.
ax, ay, az = grid_to_real(0,0,0,width,res,center,rot)
bx, by, bz = grid_to_real(width-1,width-1,width-1,width,res,center,rot)
filt_coords = []
filt_layers = []
# See TODO below.
assert rot == [1,0,0,0]
for i in range(len(atom_coords)):
x,y,z = atom_coords[i]
# TODO: this bounds check only works for a rotation vector of: [1,0,0,0]
# An alternative approach is to define real_to_grid and compute the
# nearest gridpoint for each atom.
if (
x > ax - PAD and
y > ay - PAD and
z > az - PAD and
x < bx + PAD and
y < by + PAD and
z < bz + PAD
):
filt_coords.append(atom_coords[i])
filt_layers.append(atom_layers[i])
return filt_coords, filt_layers
def mol_gridify2(
grid,
atom_coords,
atom_layers,
layer_offset,
num_layers_to_consider,
width,
res,
center,
rot,
):
# Filter atoms based on the grid bounding box.
filt_coords, filt_layers = filter_atoms(
atom_coords,
atom_layers,
width,
res,
center,
rot
)
# Fixed atom radius (r = 1.75).
r = 1.75
r2 = r * r
half_width = width / 2
# Assume rectangular tensor.
d_batch = len(grid)
d_layers = len(grid[0])
# Currently we assume the grid has a single batch entry and we will write to
# this index. To support multi-sample grids, we would probably add a
# "batch_index" parameter here and invoke mol_gridify multiple times like
# in the original grid_util.
assert d_batch == 1
# Iterate over grid points.
for x in range(width):
print(x)
for y in range(width):
for z in range(width):
# Compute the effective grid point position.
tx, ty, tz = grid_to_real(x,y,z,width,res,center,rot)
# Now compare the grid point location to each atom position.
for i in range(len(filt_coords)):
nx, ny, nz = filt_coords[i]
# print(nx,ny,nz)
# Distance squared.
d2 = (nx-tx)**2 + (ny-ty)**2 + (nz-tz)**2
if d2 > r2:
continue
ft = filt_layers[i]
# Compute effect.
# Point type: 0 (effect(d,r) = exp((-2 * d^2) / r^2))
v = math.exp((-2 * d2) / r2)
# add effect
# Acc type: 0 (sum overlapping points)
# TODO: if we implement multi-sample grids, replace 0 with
# batch_index or similar.
grid[0][layer_offset + ft][x][y][z] += v
return grid
def flatten_tensor(grid, shape):
flat = []
for i1 in range(shape[0]):
......@@ -314,12 +468,12 @@ def get_raw_batch(
# for i in range(num_samples):
rot = rand_rot()
grid = mol_gridify(
grid = mol_gridify2(
grid, p_coords, p_types, 0, parent_channels, width, res, conn, rot,
)
# TODO: Harrison should check. parent_channels was p_dim.
grid = mol_gridify(
grid = mol_gridify2(
grid, r_coords, r_types, parent_channels, rec_channels, width, res, conn, rot,
)
......
......@@ -72,8 +72,9 @@ def make_grid(receptor: str, ligand: str, grid_center: list) -> None:
for parent, frag in frags:
# use ligand directly (already fragmented)
# compute parent coords and layers
parent_coords, parent_layers = mol_to_points(parent, None, note_sulfur=False)
parent_coords, parent_layers = mol_to_points(lig, None, note_sulfur=False)
# find connection point
# __pragma__ ('skip')
......@@ -101,7 +102,7 @@ def make_grid(receptor: str, ligand: str, grid_center: list) -> None:
)
# __pragma__ ('skip')
print(json.dumps(grid))
return json.dumps(grid)
# __pragma__ ('noskip')
"""?
......@@ -112,7 +113,7 @@ def make_grid(receptor: str, ligand: str, grid_center: list) -> None:
if __name__ == "__main__":
# __pragma__ ('skip')
print(
grid = (
make_grid(
# "./1b6l/1b6l_protein.pdb",
"11gs/11gs_protein.pdb",
......@@ -122,6 +123,7 @@ if __name__ == "__main__":
[14.62, 9.944, 24.471],
)
)
open('./11gs/mol_gridify2.json', 'w').write(grid)
# __pragma__ ('noskip')
"""?
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment