jdurrant / deepfrag / Commits

Commit 38b5cee4
authored Jul 14, 2020 by hgarrereyn

work dump, add onnx stuff

parent 1c333a36
Changes 26
README.md
# Lead Optimization via Fragment Prediction

# Overview

# Structure

- `config`: configuration information (eg. TRAIN/TEST partitions)
- `data`: training/inference data (see [`data/README.md`](data/README.md))
...
config/_old_partitions.py 0 → 100644
This diff is collapsed.
config/partitions.py
This diff is collapsed.
launch.py 0 → 100644
import argparse
import sys
import os
import subprocess
import tempfile

RUN_DIR = '/zfs1/jdurrant/durrantlab/hag63/leadopt_pytorch'

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--time', default='10:00:00')
    parser.add_argument('-p', '--partition', default='gtx1080')
    parser.add_argument('-m', '--mem', default='16g')
    parser.add_argument('path')
    parser.add_argument('script')
    args = parser.parse_args()

    run_path = os.path.join(RUN_DIR, args.path)

    if os.path.exists(run_path):
        print('[!] Run exists at %s' % run_path)
        overwrite = input('- Overwrite? [Y/n]: ')
        if overwrite.lower() == 'n':
            print('Exiting...')
            exit(0)
    else:
        print('[*] Creating run directory %s' % run_path)
        os.mkdir(run_path)

    script = '''#!/bin/bash
#SBATCH --job-name={name}
#SBATCH --output={run_path}/slurm_out.txt
#SBATCH --error={run_path}/slurm_err.txt
#SBATCH --time={time}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cluster=gpu
#SBATCH --partition={partition}
#SBATCH --mail-user=hag63@pitt.edu
#SBATCH --mail-type=END,FAIL
#SBATCH --gres=gpu:1
#SBATCH --mem={mem}

cd /ihome/jdurrant/hag63/cbio/leadopt_pytorch/
./setup.sh
export PYTHONPATH=/ihome/jdurrant/hag63/cbio/leadopt_pytorch/
# export WANDB_DIR=/ihome/jdurrant/hag63/wandb_abs
export WANDB_DISABLE_CODE=true

cd {run_path}
python /ihome/jdurrant/hag63/cbio/leadopt_pytorch/{script}
'''.format(
        name='leadopt_%s' % args.path,
        run_path=run_path,
        time=args.time,
        partition=args.partition,
        mem=args.mem,
        script=args.script
    )

    print('[*] Running script...')
    with tempfile.NamedTemporaryFile('w') as f:
        f.write(script)
        f.flush()
        r = subprocess.run(
            'sbatch %s' % f.name,
            shell=True,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE
        )
        print(r)

if __name__ == '__main__':
    main()
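For reference, a hypothetical invocation would be `python launch.py -t 24:00:00 -p gtx1080 -m 16g test_run train.py` (the run name and script are placeholders): the wrapper creates or reuses the run directory under RUN_DIR, renders the SBATCH template above, writes it to a temporary file, and submits it with sbatch.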
launch.sh 0 → 100644
#!/bin/bash

DEFAULT_RUN_PATH=/zfs1/jdurrant/durrantlab/hag63/leadopt_pytorch/

if [[ $# -ne 4 ]]; then
    echo "Usage: $0 <run_name> <gpu_partition> <**args>"
    exit 0
fi

ABS_SCRIPT=$(pwd)/train.py

# navigate to runs directory
RUNS_DIR="${RUNS_DIR:-$DEFAULT_RUN_PATH}"
cd $RUNS_DIR

if [[ -d $1 ]]; then
    echo "Warning: run directory $1 already exists!"
    exit -1
fi

echo "Creating run directory ($1)..."
mkdir $1

echo "Running script..."
sbatch << EOT
#!/bin/bash
#SBATCH --job-name=$1
#SBATCH --output=$RUNS_DIR/$1/slurm_out.txt
#SBATCH --error=$RUNS_DIR/$1/slurm_err.txt
#SBATCH --time=10:00:00
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cluster=gpu
#SBATCH --partition=$2
#SBATCH --mail-user=hag63@pitt.edu
#SBATCH --mail-type=END,FAIL
#SBATCH --gres=gpu:1
#SBATCH --mem=80g

cd /ihome/jdurrant/hag63/cbio/leadopt_pytorch/
source ./setup.sh
PYTHON_PATH=$PYTHON_PATH:/ihome/jdurrant/hag63/cbio/leadopt_pytorch/

cd $RUNS_DIR/$1
python train.py $4
EOT
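As written, launch.sh requires exactly four positional arguments even though only `$1` (run name), `$2` (GPU partition), and `$4` (arguments forwarded to train.py) are referenced; a hypothetical invocation would be `RUNS_DIR=/tmp/runs ./launch.sh my_run gtx1080 x "--args"`, with `RUNS_DIR` falling back to DEFAULT_RUN_PATH when unset.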
leadopt.py
...
...
@@ -13,52 +13,9 @@ import h5py
from leadopt.models.voxel import VoxelFingerprintNet2
from leadopt.infer import infer_all
+from leadopt.pretrained import MODELS

-class SavedModel(object):
-
-    model_class = None
-    model_args = None
-
-    @classmethod
-    def load(cls, path):
-        m = cls.model_class(**cls.model_args).cuda()
-        m.load_state_dict(torch.load(path))
-        m.eval()
-        return m
-
-    @classmethod
-    def get_fingerprints(cls, path):
-        f = h5py.File(os.path.join(path, cls.fingerprint_data), 'r')
-        data = f['fingerprints'][()]
-        smiles = f['smiles'][()]
-        f.close()
-        return data, smiles
-
-class V2_RDK_M150(SavedModel):
-    model_class = VoxelFingerprintNet2
-    model_args = {
-        'in_channels': 18,
-        'output_size': 2048,
-        'batchnorm': True,
-        'sigmoid': True
-    }
-
-    grid_width = 24
-    grid_res = 1
-
-    receptor_types = [6, 7, 8, 9, 15, 16, 17, 35, 53]
-    parent_types = [6, 7, 8, 9, 15, 16, 17, 35, 53]
-
-    fingerprint_data = 'fingerprint_rdk_2048.h5'
-
-MODELS = {
-    'rdk_m150': V2_RDK_M150
-}
def main():
    parser = argparse.ArgumentParser()
...
...
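For orientation, the MODELS registry (now apparently imported from leadopt.pretrained) would be used roughly as in this sketch; the weights path and data directory are hypothetical:

# Sketch only; paths are placeholders, not from this commit.
from leadopt.pretrained import MODELS

model_cls = MODELS['rdk_m150']
model = model_cls.load('weights/rdk_m150.pt')     # CUDA model in eval() mode
fps, smiles = model_cls.get_fingerprints('data')  # reads fingerprint_rdk_2048.h5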
leadopt/data_util.py
...
...
@@ -8,28 +8,161 @@ import os
from torch.utils.data import DataLoader, Dataset

import numpy as np
import h5py
import tqdm

# Default atomic numbers to use as grid layers
DEFAULT_TYPES = [6, 7, 8, 9, 15, 16, 17, 35, 53]

-def remap_atoms(atom_types):
-    '''
-    Returns a function that maps an atomic number to layer index.
-
-    Params:
-    - atom_types: which atom types to use as layers
-    '''
-    atom_mapping = {atom_types[i]: i for i in range(len(atom_types))}
-
-    def f(x):
-        if x in atom_mapping:
-            return atom_mapping[x]
-        else:
-            return -1
-
-    return f
+def rec_typer_single(t):
+    '''single channel no hydrogen'''
+    # unpack features
+    num, hacc, hdon, aro, _ = t
+    if num != 1:
+        return 0
+    else:
+        return -1
+
+def rec_typer_single_h(t):
+    '''single channel with hydrogen'''
+    # unpack features
+    return 0
+
+def rec_typer_basic(t):
+    '''simple types'''
+    BASIC_TYPES = [6, 7, 8, 16]
+    num, hacc, hdon, aro, _ = t
+    if num in BASIC_TYPES:
+        return BASIC_TYPES.index(num)
+    else:
+        return -1
+
+def rec_typer_basic_h(t):
+    '''simple types with hydrogen'''
+    BASIC_TYPES = [1, 6, 7, 8, 16]
+    num, hacc, hdon, aro, _ = t
+    if num in BASIC_TYPES:
+        return BASIC_TYPES.index(num)
+    else:
+        return -1
+
+# aro, hdon, hacc
+REC_DESC = [
+    (6, 0, 0, 0),   # C
+    (6, 1, 0, 0),   # C-aro
+    (7, 0, 0, 0),   # N
+    (7, 0, 0, 1),   # N-hacc
+    (7, 0, 1, 0),   # N-hdon
+    (7, 1, 0, 0),   # N-aro
+    (7, 0, 1, 1),   # N-hdon-hacc
+    (7, 1, 0, 1),   # N-aro-hacc
+    (7, 1, 1, 0),   # N-aro-hdon
+    (8, 0, 0, 0),   # O
+    (8, 0, 0, 1),   # O-hacc
+    (8, 0, 1, 1),   # O-hdon-hacc
+    (16, 0, 0, 0),  # S
+]
+
+REC_DESC_H = [
+    (1, 0, 0, 0),   # H
+    (6, 0, 0, 0),   # C
+    (6, 1, 0, 0),   # C-hacc
+    (7, 0, 0, 0),   # N
+    (7, 0, 0, 1),   # N-aro
+    (7, 0, 1, 0),   # N-hdon
+    (7, 1, 0, 0),   # N-hacc
+    (7, 0, 1, 1),   # N-hdon-aro
+    (7, 1, 0, 1),   # N-hacc-aro
+    (7, 1, 1, 0),   # N-hacc-hdon
+    (8, 0, 0, 0),   # O
+    (8, 0, 0, 1),   # O-aro
+    (8, 0, 1, 1),   # O-hdon-aro
+    (16, 0, 0, 0),  # S
+]
+
+def rec_typer_desc(t):
+    '''descriptive types'''
+    num, hacc, hdon, aro, _ = t
+    f = (num, hacc, hdon, aro)
+    if f in REC_DESC:
+        return REC_DESC.index(f)
+    else:
+        return -1
+
+def rec_typer_desc_h(t):
+    '''descriptive types'''
+    num, hacc, hdon, aro, _ = t
+    f = (num, hacc, hdon, aro)
+    if f in REC_DESC_H:
+        return REC_DESC_H.index(f)
+    else:
+        return -1
+
+def lig_typer_single(t):
+    if t != 1:
+        return 0
+    else:
+        return -1
+
+def lig_typer_single_h(t):
+    return 0
+
+def lig_typer_simple(t):
+    BASIC_TYPES = [6, 7, 8]
+    if t in BASIC_TYPES:
+        return BASIC_TYPES.index(t)
+    else:
+        return -1
+
+def lig_typer_simple_h(t):
+    BASIC_TYPES = [1, 6, 7, 8]
+    if t in BASIC_TYPES:
+        return BASIC_TYPES.index(t)
+    else:
+        return -1
+
+def lig_typer_desc(t):
+    DESC_TYPES = [6, 7, 8, 16, 9, 15, 17, 35, 5, 53]
+    if t in DESC_TYPES:
+        return DESC_TYPES.index(t)
+    else:
+        return -1
+
+def lig_typer_desc_h(t):
+    DESC_TYPES = [1, 6, 7, 8, 16, 9, 15, 17, 35, 5, 53]
+    if t in DESC_TYPES:
+        return DESC_TYPES.index(t)
+    else:
+        return -1
+
+REC_TYPER = {
+    'single': rec_typer_single,
+    'single_h': rec_typer_single_h,
+    'simple': rec_typer_basic,
+    'simple_h': rec_typer_basic_h,
+    'desc': rec_typer_desc,
+    'desc_h': rec_typer_desc_h
+}
+
+LIG_TYPER = {
+    'single': lig_typer_single,
+    'single_h': lig_typer_single_h,
+    'simple': lig_typer_simple,
+    'simple_h': lig_typer_simple_h,
+    'desc': lig_typer_desc,
+    'desc_h': lig_typer_desc_h
+}

class FragmentDataset(Dataset):
...
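To make the typer contract concrete, a small sketch (values are illustrative, not from this commit): receptor typers consume a `(num, hacc, hdon, aro, _)` feature row, ligand typers a bare atomic number, and `-1` marks an atom that gets no grid layer.

# Illustrative checks against the typers defined above.
assert rec_typer_basic((6, 0, 0, 0, 0)) == 0    # carbon -> layer 0
assert rec_typer_basic((1, 0, 0, 0, 0)) == -1   # hydrogen -> invisible
assert lig_typer_simple(7) == 1                 # nitrogen -> layer 1
assert lig_typer_simple(16) == -1               # sulfur not in [6, 7, 8]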
@@ -37,15 +170,17 @@ class FragmentDataset(Dataset):
    Utility class to work with the packed fragments.h5 format
    '''

-    def __init__(self, fragment_file, fingerprint_file, filter_rec=None, atom_types=DEFAULT_TYPES, fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None):
+    def __init__(self, fragment_file, fingerprint_file, rec_typer, lig_typer, filter_rec=None, fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None, verbose=False):
        '''
        Initialize the fragment dataset

        Params:
        - fragment_file: path to fragments.h5
        - fingerprint_file: path to fingerprints.h5
+        - rec_typer: function to map receptor rows to layer index
+        - lig_typer: function to map ligand rows to layer index
        - filter_rec: list of receptor ids to use (or None to use all)
-        - atom_types: which atom types to use as layers

        Filtering options:
        - fdist_min: minimum fragment distance
...
@@ -53,9 +188,11 @@ class FragmentDataset(Dataset):
        - fmass_min: minimum fragment mass (Da)
        - fmass_max: maximum fragment mass (Da)
        '''
+        self.verbose = verbose
+
        # load receptor/fragment information
-        self.rec = self.load_rec(fragment_file, atom_types)
-        self.frag = self.load_fragments(fragment_file, atom_types)
+        self.rec = self.load_rec(fragment_file, rec_typer)
+        self.frag = self.load_fragments(fragment_file, lig_typer)

        # load fingerprint information
        self.fingerprints = self.load_fingerprints(fingerprint_file)
...
...
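A hypothetical instantiation with the new typer arguments (file paths and filter values are placeholders):

# Sketch only; paths are placeholders, not from this commit.
dataset = FragmentDataset(
    'data/fragments.h5',
    'data/fingerprints.h5',
    rec_typer=REC_TYPER['simple'],
    lig_typer=LIG_TYPER['simple'],
    fdist_min=0.5,    # minimum fragment distance
    fmass_max=150,    # maximum fragment mass (Da)
    verbose=True
)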
@@ -91,19 +228,22 @@ class FragmentDataset(Dataset):
        self.valid_fingerprints = self.compute_valid_fingerprints()

-    def load_rec(self, fragment_file, atom_types):
+    def load_rec(self, fragment_file, rec_typer):
        '''Load receptor information'''
        f = h5py.File(fragment_file, 'r')

        rec_data = f['rec_data'][()]
-        rec_coords = f['rec_coords'][()]
-        rec_types = f['rec_types'][()]
        rec_lookup = f['rec_lookup'][()]

        # unpack rec data into separate structures
        rec_coords = rec_data[:,:3].astype(np.float32)
        rec_types = rec_data[:,3].reshape(-1,1).astype(np.int32)

-        rec_remapped = np.vectorize(remap_atoms(atom_types))(rec_types)
+        r = range(len(rec_types))
+        if self.verbose:
+            r = tqdm.tqdm(r, desc='Remap receptor atoms')
+
+        rec_remapped = np.zeros(len(rec_types)).astype(np.int32)
+        for i in r:
+            rec_remapped[i] = rec_typer(rec_types[i])

        # create rec mapping
        rec_mapping = {}
        for i in range(len(rec_lookup)):
...
...
@@ -121,7 +261,7 @@ class FragmentDataset(Dataset):
        return rec

-    def load_fragments(self, fragment_file, atom_types):
+    def load_fragments(self, fragment_file, lig_typer):
        '''Load fragment information'''
        f = h5py.File(fragment_file, 'r')
...
...
@@ -134,13 +274,17 @@ class FragmentDataset(Dataset):
        # unpack frag data into separate structures
        frag_coords = frag_data[:,:3].astype(np.float32)
-        frag_types = frag_data[:,3].reshape(-1,1).astype(np.int32)
+        frag_types = frag_data[:,3].astype(np.int32)

-        frag_remapped = np.vectorize(remap_atoms(atom_types))(frag_types)
+        frag_remapped = np.vectorize(lig_typer)(frag_types)

        # find and save connection point
+        r = range(len(frag_lookup))
+        if self.verbose:
+            r = tqdm.tqdm(r, desc='Frag connection point')
+
        frag_conn = np.zeros((len(frag_lookup), 3))
-        for i in range(len(frag_lookup)):
+        for i in r:
            _, f_start, f_end, _, _ = frag_lookup[i]
            fdat = frag_data[f_start:f_end]
...
...
leadopt/grid_util.py
...
...
@@ -67,7 +67,8 @@ def gpu_gridify(grid, width, res, center, rot, atom_num, atom_coords, atom_types
    while i < atom_num:
        # fetch atom
        fx, fy, fz = atom_coords[i]
-        ft = atom_types[i][0]
+        # ft = atom_types[i][0]
+        ft = atom_types[i]
        i += 1

        # invisible atoms
...
...
@@ -152,13 +153,15 @@ def mol_gridify(
)

-def get_batch(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False, include_freq=False):
+def get_batch(data, rec_channels, parent_channels, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False, include_freq=False):
    assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"

-    dim = 18
-    if ignore_receptor or ignore_parent:
-        dim = 9
+    dim = 0
+    if not ignore_receptor:
+        dim += rec_channels
+    if not ignore_parent:
+        dim += parent_channels

    # create a tensor with shared memory on the gpu
    t, grid = make_tensor((batch_size, dim, width, width, width))
...
...
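The effect of the new channel arguments, restated as a standalone sketch (the helper name is hypothetical):

# Hypothetical helper mirroring the dim computation in get_batch above.
def grid_channels(rec_channels, parent_channels, ignore_receptor=False, ignore_parent=False):
    dim = 0
    if not ignore_receptor:
        dim += rec_channels
    if not ignore_parent:
        dim += parent_channels
    return dim

assert grid_channels(9, 9) == 18                     # parent layers 0-8, receptor layers 9-17
assert grid_channels(9, 9, ignore_parent=True) == 9  # receptor grid only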
@@ -182,7 +185,7 @@ def get_batch(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_rec
            mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
        else:
            mol_gridify(grid, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
-            mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=9)
+            mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=parent_channels)

        fingerprints[i] = fp
        freq[i] = extra['freq']
...
...
@@ -196,63 +199,63 @@ def get_batch(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_rec
    return t, t_fingerprints, batch_set

-def get_batch_dual(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False):
-    # get batch
-    t, fp, batch_set = get_batch(data, batch_set, batch_size, width, res, ignore_receptor, ignore_parent)
-
-    f = data.fingerprints['fingerprint_data']
-
-    # corrupt fingerprints
-    false_fp = torch.clone(fp)
-    for i in range(batch_size):
-        # idx = np.random.randint(fp.shape[1])
-        # false_fp[i,idx] = (1 - false_fp[i,idx]) # flip
-        idx = np.random.randint(f.shape[0])
-        false_fp[i] = torch.Tensor(f[idx]) # replace
-
-    comb_t = torch.cat([t,t], axis=0)
-    comb_fp = torch.cat([fp, false_fp], axis=0)
-
-    y = torch.zeros((batch_size * 2,1)).cuda()
-    y[:batch_size] = 1
-
-    return (comb_t, comb_fp, y, batch_set)
+# def get_batch_dual(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False):
+#     # get batch
+#     t, fp, batch_set = get_batch(data, batch_set, batch_size, width, res, ignore_receptor, ignore_parent)
+#
+#     f = data.fingerprints['fingerprint_data']
+#
+#     # corrupt fingerprints
+#     false_fp = torch.clone(fp)
+#     for i in range(batch_size):
+#         # idx = np.random.randint(fp.shape[1])
+#         # false_fp[i,idx] = (1 - false_fp[i,idx]) # flip
+#         idx = np.random.randint(f.shape[0])
+#         false_fp[i] = torch.Tensor(f[idx]) # replace
+#
+#     comb_t = torch.cat([t,t], axis=0)
+#     comb_fp = torch.cat([fp, false_fp], axis=0)
+#
+#     y = torch.zeros((batch_size * 2,1)).cuda()
+#     y[:batch_size] = 1
+#
+#     return (comb_t, comb_fp, y, batch_set)

-def get_batch_full(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False):
-    assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"
-
-    dim = 18
-    if ignore_receptor or ignore_parent:
-        dim = 9
-
-    # create a tensor with shared memory on the gpu
-    t_context, grid_context = make_tensor((batch_size, dim, width, width, width))
-    t_frag, grid_frag = make_tensor((batch_size, 9, width, width, width))
-
-    if batch_set is None:
-        batch_set = np.random.choice(len(data), size=batch_size, replace=False)
-
-    for i in range(len(batch_set)):
-        idx = batch_set[i]
-        f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fp = data[idx]
-
-        # random rotation
-        rot = rand_rot()
-
-        if ignore_receptor:
-            mol_gridify(grid_context, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
-        elif ignore_parent:
-            mol_gridify(grid_context, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
-        else:
-            mol_gridify(grid_context, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
-            mol_gridify(grid_context, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=9)
-
-        mol_gridify(grid_frag, f_coords, f_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
-
-    return t_context, t_frag, batch_set
+# def get_batch_full(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False):
+#     assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"
+#
+#     dim = 18
+#     if ignore_receptor or ignore_parent:
+#         dim = 9
+#
+#     # create a tensor with shared memory on the gpu
+#     t_context, grid_context = make_tensor((batch_size, dim, width, width, width))
+#     t_frag, grid_frag = make_tensor((batch_size, 9, width, width, width))
+#
+#     if batch_set is None:
+#         batch_set = np.random.choice(len(data), size=batch_size, replace=False)
+#
+#     for i in range(len(batch_set)):
+#         idx = batch_set[i]
+#         f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fp = data[idx]
+#
+#         # random rotation
+#         rot = rand_rot()
+#
+#         if ignore_receptor:
+#             mol_gridify(grid_context, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
+#         elif ignore_parent:
+#             mol_gridify(grid_context, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
+#         else:
+#             mol_gridify(grid_context, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
+#             mol_gridify(grid_context, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=9)
+#
+#         mol_gridify(grid_frag, f_coords, f_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
+#
+#     return t_context, t_frag, batch_set

def get_raw_batch(r_coords, r_types, p_coords, p_types, conn, num_samples=32, width=24, res=1, r_dim=9, p_dim=9):
...
...
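For reference, a call to get_raw_batch might look like the sketch below; the arrays are zero-filled placeholders with assumed shapes, since the function body is elided above.

import numpy as np

# Placeholder inputs; shapes are assumptions, not from this commit.
r_coords = np.zeros((120, 3), dtype=np.float32)  # receptor atom coordinates
r_types = np.zeros(120, dtype=np.int32)          # receptor layer indices
p_coords = np.zeros((24, 3), dtype=np.float32)   # parent atom coordinates
p_types = np.zeros(24, dtype=np.int32)           # parent layer indices
conn = np.zeros(3, dtype=np.float32)             # fragment connection point

batch = get_raw_batch(r_coords, r_types, p_coords, p_types, conn, num_samples=32, width=24, res=1)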
leadopt/infer.py
...
...
@@ -12,7 +12,7 @@ import h5py
import tqdm

from leadopt.grid_util import get_raw_batch
from leadopt.util import generate_fragments