Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
jdurrant
deepfrag
Commits
4d81a065
Commit
4d81a065
authored
Jan 06, 2021
by
jdurrant
Browse files
Added license text to files.
parent
d6e0f779
Pipeline
#323
failed with stages
in 0 seconds
Changes
20
Pipelines
1
Expand all
Hide whitespace changes
Inline
Side-by-side
LICENSE.md
View file @
4d81a065
...
...
@@ -172,27 +172,3 @@ any liability incurred by, or claims asserted against, such Contributor by
reason of your accepting any such warranty or additional liability.
_END OF TERMS AND CONDITIONS_
### APPENDIX: How to apply the Apache License to your work
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets
`[]`
replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification
within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
config/__init__.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
config/moad_partitions.py
View file @
4d81a065
This diff is collapsed.
Click to expand it.
leadopt/__init__.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
leadopt/data_util.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Contains utility code for reading packed data files.
"""
...
...
@@ -10,15 +25,15 @@ import h5py
import
tqdm
# Atom typing
#
#
# Atom typing is the process of figuring out which layer each atom should be
# written to. For ease of testing, the packed data file contains a lot of
# potentially useful atomic information which can be distilled during the
# data loading process.
#
#
# Atom typing is implemented by map functions of the type:
# (atom descriptor) -> (layer index)
#
#
# If the layer index is -1, the atom is ignored.
...
...
@@ -139,13 +154,13 @@ LIG_TYPER = {
class
FragmentDataset
(
Dataset
):
"""Utility class to work with the packed fragments.h5 format."""
def
__init__
(
self
,
fragment_file
,
rec_typer
=
REC_TYPER
[
'simple'
],
lig_typer
=
LIG_TYPER
[
'simple'
],
filter_rec
=
None
,
filter_smi
=
None
,
fdist_min
=
None
,
fdist_max
=
None
,
fmass_min
=
None
,
fmass_max
=
None
,
fdist_min
=
None
,
fdist_max
=
None
,
fmass_min
=
None
,
fmass_max
=
None
,
verbose
=
False
,
lazy_loading
=
True
):
"""Initializes the fragment dataset.
Args:
fragment_file: path to fragments.h5
rec_typer: AtomTyper for receptor
...
...
@@ -174,11 +189,11 @@ class FragmentDataset(Dataset):
def
_load_rec
(
self
,
fragment_file
,
rec_typer
):
"""Loads receptor information."""
f
=
h5py
.
File
(
fragment_file
,
'r'
)
rec_coords
=
f
[
'rec_coords'
][()]
rec_types
=
f
[
'rec_types'
][()]
rec_lookup
=
f
[
'rec_lookup'
][()]
r
=
range
(
len
(
rec_types
))
if
self
.
verbose
:
r
=
tqdm
.
tqdm
(
r
,
desc
=
'Remap receptor atoms'
)
...
...
@@ -194,7 +209,7 @@ class FragmentDataset(Dataset):
rec_mapping
=
{}
for
i
in
range
(
len
(
rec_lookup
)):
rec_mapping
[
rec_lookup
[
i
][
0
].
decode
(
'ascii'
)]
=
i
rec
=
{
'rec_coords'
:
rec_coords
,
'rec_types'
:
rec_types
,
...
...
@@ -203,11 +218,11 @@ class FragmentDataset(Dataset):
'rec_mapping'
:
rec_mapping
,
'rec_loaded'
:
rec_loaded
}
f
.
close
()
return
rec
def
_load_fragments
(
self
,
fragment_file
,
lig_typer
):
"""Loads fragment information."""
f
=
h5py
.
File
(
fragment_file
,
'r'
)
...
...
@@ -227,12 +242,12 @@ class FragmentDataset(Dataset):
# unpack frag data into separate structures
frag_coords
=
frag_data
[:,:
3
].
astype
(
np
.
float32
)
frag_types
=
frag_data
[:,
3
].
astype
(
np
.
uint8
)
frag_remapped
=
np
.
zeros
(
len
(
frag_types
),
dtype
=
np
.
uint16
)
if
not
self
.
_lazy_loading
:
for
i
in
range
(
len
(
frag_types
)):
frag_remapped
[
i
]
=
lig_typer
.
apply
(
frag_types
[
i
])
frag_loaded
=
np
.
zeros
(
len
(
frag_lookup
)).
astype
(
np
.
bool
)
# find and save connection point
...
...
@@ -244,7 +259,7 @@ class FragmentDataset(Dataset):
for
i
in
r
:
_
,
f_start
,
f_end
,
_
,
_
=
frag_lookup
[
i
]
fdat
=
frag_data
[
f_start
:
f_end
]
found
=
False
for
j
in
range
(
len
(
fdat
)):
if
fdat
[
j
][
3
]
==
0
:
...
...
@@ -253,7 +268,7 @@ class FragmentDataset(Dataset):
break
assert
found
,
"missing fragment connection point at %d"
%
i
frag
=
{
'frag_coords'
:
frag_coords
,
# d_idx -> (x,y,z)
'frag_types'
:
frag_types
,
# d_idx -> (type)
...
...
@@ -267,7 +282,7 @@ class FragmentDataset(Dataset):
'frag_lig_idx'
:
frag_lig_idx
,
'frag_loaded'
:
frag_loaded
}
f
.
close
()
return
frag
...
...
@@ -275,14 +290,14 @@ class FragmentDataset(Dataset):
def
_get_valid_examples
(
self
,
filter_rec
,
filter_smi
,
fdist_min
,
fdist_max
,
fmass_min
,
fmass_max
,
verbose
):
"""Returns an array of valid fragment indexes.
"Valid" in this context means the fragment belongs to a receptor in
filter_rec and the fragment abides by the optional mass/distance
constraints.
"""
# keep track of valid examples
valid_mask
=
np
.
ones
(
self
.
frag
[
'frag_lookup'
].
shape
[
0
]).
astype
(
np
.
bool
)
num_frags
=
self
.
frag
[
'frag_lookup'
].
shape
[
0
]
# filter by receptor id
...
...
@@ -298,7 +313,7 @@ class FragmentDataset(Dataset):
if
rec
in
filter_rec
:
valid_rec
[
i
]
=
1
valid_mask
*=
valid_rec
# filter by ligand smiles string
if
filter_smi
is
not
None
:
valid_lig
=
np
.
zeros
(
num_frags
,
dtype
=
np
.
bool
)
...
...
@@ -318,10 +333,10 @@ class FragmentDataset(Dataset):
# filter by fragment distance
if
fdist_min
is
not
None
:
valid_mask
[
self
.
frag
[
'frag_dist'
]
<
fdist_min
]
=
0
if
fdist_max
is
not
None
:
valid_mask
[
self
.
frag
[
'frag_dist'
]
>
fdist_max
]
=
0
# filter by fragment mass
if
fmass_min
is
not
None
:
valid_mask
[
self
.
frag
[
'frag_mass'
]
<
fmass_min
]
=
0
...
...
@@ -337,10 +352,10 @@ class FragmentDataset(Dataset):
def
__len__
(
self
):
"""Returns the number of valid fragment examples."""
return
self
.
valid_idx
.
shape
[
0
]
def
__getitem__
(
self
,
idx
):
"""Returns the Nth example.
Returns a dict with:
f_coords: fragment coordinates (Fx3)
f_types: fragment layers (Fx1)
...
...
@@ -354,23 +369,23 @@ class FragmentDataset(Dataset):
# convert to fragment index
frag_idx
=
self
.
valid_idx
[
idx
]
return
self
.
get_raw
(
frag_idx
)
def
get_raw
(
self
,
frag_idx
):
# lookup fragment
rec_id
,
f_start
,
f_end
,
p_start
,
p_end
=
self
.
frag
[
'frag_lookup'
][
frag_idx
]
smiles
=
self
.
frag
[
'frag_smiles'
][
frag_idx
].
decode
(
'ascii'
)
conn
=
self
.
frag
[
'frag_conn'
][
frag_idx
]
# lookup receptor
rec_idx
=
self
.
rec
[
'rec_mapping'
][
rec_id
.
decode
(
'ascii'
)]
_
,
r_start
,
r_end
=
self
.
rec
[
'rec_lookup'
][
rec_idx
]
# fetch data
# f_coords = self.frag['frag_coords'][f_start:f_end]
# f_types = self.frag['frag_types'][f_start:f_end]
p_coords
=
self
.
frag
[
'frag_coords'
][
p_start
:
p_end
]
r_coords
=
self
.
rec
[
'rec_coords'
][
r_start
:
r_end
]
if
self
.
_lazy_loading
and
self
.
frag
[
'frag_loaded'
][
frag_idx
]
==
0
:
frag_types
=
self
.
frag
[
'frag_types'
]
frag_remapped
=
self
.
frag
[
'frag_remapped'
]
...
...
@@ -468,29 +483,29 @@ class FingerprintDataset(Dataset):
def
_load_fingerprints
(
self
,
fingerprint_file
):
"""Loads fingerprint information."""
f
=
h5py
.
File
(
fingerprint_file
,
'r'
)
fingerprint_data
=
f
[
'fingerprints'
][()]
fingerprint_smiles
=
f
[
'smiles'
][()]
# create smiles->idx mapping
fingerprint_mapping
=
{}
for
i
in
range
(
len
(
fingerprint_smiles
)):
sm
=
fingerprint_smiles
[
i
].
decode
(
'ascii'
)
fingerprint_mapping
[
sm
]
=
i
fingerprints
=
{
'fingerprint_data'
:
fingerprint_data
,
'fingerprint_mapping'
:
fingerprint_mapping
,
'fingerprint_smiles'
:
fingerprint_smiles
,
}
f
.
close
()
return
fingerprints
def
for_smiles
(
self
,
smiles
):
"""Return a Tensor of fingerprints for a list of smiles.
Args:
smiles: size N list of smiles strings (as str not bytes)
"""
...
...
leadopt/grid_util.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Contains code for gpu-accelerated grid generation.
"""
...
...
@@ -120,15 +135,15 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
# invisible atoms
if
mask
==
0
:
continue
# point radius squared
r
=
point_radius
r2
=
point_radius
*
point_radius
# quick cube bounds check
if
abs
(
fx
-
tx
)
>
r2
or
abs
(
fy
-
ty
)
>
r2
or
abs
(
fz
-
tz
)
>
r2
:
continue
# value to add to this gridpoint
val
=
0
...
...
@@ -147,7 +162,7 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
d2
=
(
fx
-
tx
)
**
2
+
(
fy
-
ty
)
**
2
+
(
fz
-
tz
)
**
2
if
d2
>
r2
:
continue
val
=
1
elif
point_type
==
2
:
# POINT_TYPE.CUBE
# solid cube fill
...
...
@@ -290,7 +305,7 @@ def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5,
rot
=
fixed_rot
if
rot
is
None
:
rot
=
rand_rot
()
if
ignore_receptor
:
mol_gridify
(
cuda_grid
,
...
...
@@ -420,5 +435,5 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
point_type
=
point_type
,
acc_type
=
acc_type
)
return
torch_grid
leadopt/infer.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import
os
...
...
leadopt/metrics.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import
torch
import
torch.nn
as
nn
...
...
@@ -36,7 +50,7 @@ def broadcast_fn(fn, yp, yt):
def
average_position
(
fingerprints
,
fn
,
norm
=
True
):
"""Returns the average ranking of the correct fragment relative to all
possible fragments.
Args:
fingerprints: NxF tensor of fingerprint data
fn: distance function to compare fingerprints
...
...
@@ -54,7 +68,7 @@ def average_position(fingerprints, fn, norm=True):
# number of fragment that are closer or equal
count
=
torch
.
sum
((
dist
<=
p_dist
[
i
]).
to
(
torch
.
float
))
c
[
i
]
=
count
score
=
torch
.
mean
(
c
)
return
score
...
...
@@ -148,7 +162,7 @@ def top_k_acc(fingerprints, fn, k, pre=''):
for
j
in
range
(
len
(
k
)):
c
[
i
,
j
]
=
int
(
count
<
k
[
j
])
score
=
torch
.
mean
(
c
,
0
)
m
=
{
'%sacc_%d'
%
(
pre
,
h
):
v
.
item
()
for
h
,
v
in
zip
(
k
,
score
)}
...
...
leadopt/model_conf.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import
os
import
json
...
...
@@ -59,7 +73,7 @@ LOSS_TYPE = {
class
RunLog
(
object
):
def
__init__
(
self
,
args
,
models
,
wandb_project
=
None
):
"""Initialize a run logger.
Args:
args: command line training arguments
models: {name: model} mapping
...
...
@@ -71,7 +85,7 @@ class RunLog(object):
project
=
wandb_project
,
config
=
args
)
for
m
in
models
:
wandb
.
watch
(
models
[
m
])
...
...
@@ -131,7 +145,7 @@ class LeadoptModel(object):
@
staticmethod
def
setup_parser
(
sub
):
"""Adds arguments to a subparser.
Args:
sub: an argparse subparser
"""
...
...
@@ -147,7 +161,7 @@ class LeadoptModel(object):
Call LeadoptModel.load to infer model type.
Or call subclass.load to load a specific model type.
Args:
path: full path to saved model
"""
...
...
@@ -181,7 +195,7 @@ class LeadoptModel(object):
def
save
(
self
,
path
):
"""Save model configuration to a path.
Args:
path: path to an existing directory to save models
"""
...
...
@@ -303,9 +317,9 @@ class VoxelNet(LeadoptModel):
# # partitions.TRAIN if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.TRAIN)),
# filter_smi=set(moad_partitions.TRAIN_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )
...
...
@@ -318,9 +332,9 @@ class VoxelNet(LeadoptModel):
# # partitions.VAL if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.VAL)),
# filter_smi=set(moad_partitions.VAL_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )
...
...
@@ -336,9 +350,9 @@ class VoxelNet(LeadoptModel):
dat
,
filter_rec
=
set
(
get_bios
(
moad_partitions
.
TRAIN
)),
filter_smi
=
set
(
moad_partitions
.
TRAIN_SMI
),
fdist_min
=
self
.
_args
[
'fdist_min'
],
fdist_max
=
self
.
_args
[
'fdist_max'
],
fmass_min
=
self
.
_args
[
'fmass_min'
],
fdist_min
=
self
.
_args
[
'fdist_min'
],
fdist_max
=
self
.
_args
[
'fdist_max'
],
fmass_min
=
self
.
_args
[
'fmass_min'
],
fmass_max
=
self
.
_args
[
'fmass_max'
],
)
...
...
@@ -346,9 +360,9 @@ class VoxelNet(LeadoptModel):
dat
,
filter_rec
=
set
(
get_bios
(
moad_partitions
.
VAL
)),
filter_smi
=
set
(
moad_partitions
.
VAL_SMI
),
fdist_min
=
self
.
_args
[
'fdist_min'
],
fdist_max
=
self
.
_args
[
'fdist_max'
],
fmass_min
=
self
.
_args
[
'fmass_min'
],
fdist_min
=
self
.
_args
[
'fdist_min'
],
fdist_max
=
self
.
_args
[
'fdist_max'
],
fmass_min
=
self
.
_args
[
'fmass_min'
],
fmass_max
=
self
.
_args
[
'fmass_max'
],
)
...
...
@@ -391,7 +405,7 @@ class VoxelNet(LeadoptModel):
self
.
_models
[
'voxel'
].
parameters
(),
lr
=
self
.
_args
[
'learning_rate'
])
steps_per_epoch
=
len
(
train_dat
)
//
self
.
_args
[
'batch_size'
]
steps_per_epoch
=
custom_steps
if
custom_steps
is
not
None
else
steps_per_epoch
# configure metrics
dist_fn
=
DIST_FN
[
self
.
_args
[
'dist_fn'
]]
...
...
@@ -505,9 +519,9 @@ class VoxelNet(LeadoptModel):
# filter_rec=partitions.TEST,
filter_rec
=
set
(
get_bios
(
moad_partitions
.
VAL
if
use_val
else
moad_partitions
.
TEST
)),
filter_smi
=
set
(
moad_partitions
.
VAL_SMI
if
use_val
else
moad_partitions
.
TEST_SMI
),
fdist_min
=
self
.
_args
[
'fdist_min'
],
fdist_max
=
self
.
_args
[
'fdist_max'
],
fmass_min
=
self
.
_args
[
'fmass_min'
],
fdist_min
=
self
.
_args
[
'fdist_min'
],
fdist_max
=
self
.
_args
[
'fdist_max'
],
fmass_min
=
self
.
_args
[
'fmass_min'
],
fmass_max
=
self
.
_args
[
'fmass_max'
],
verbose
=
True
)
...
...
@@ -523,7 +537,7 @@ class VoxelNet(LeadoptModel):
smiles
=
[
test_dat
[
i
][
'smiles'
]
for
i
in
range
(
len
(
test_dat
))]
correct_fp
=
fingerprints
.
for_smiles
(
smiles
).
numpy
()
# (example_idx, sample_idx)
queries
=
[]
for
i
in
range
(
len
(
test_dat
)):
...
...
leadopt/models/__init__.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
leadopt/models/backport.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from
torch.nn
import
Module
from
typing
import
Tuple
,
Union
...
...
leadopt/models/voxel.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import
torch
import
torch.nn
as
nn
...
...
@@ -56,12 +70,12 @@ class VoxelFingerprintNet(nn.Module):
self
.
pred
=
nn
.
Sequential
(
*
pred
)
self
.
norm
=
nn
.
Sigmoid
()
def
forward
(
self
,
x
):
for
b
in
self
.
blocks
:
x
=
b
(
x
)
x
=
self
.
reduce
(
x
)
x
=
self
.
pred
(
x
)
x
=
self
.
norm
(
x
)
return
x
leadopt/util.py
View file @
4d81a065
# Copyright 2021 Jacob Durrant