Skip to content
GitLab
Explore
Projects
Groups
Snippets
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Matt Harvey
five-video-classification-methods
Commits
79e29be4
Commit
79e29be4
authored
7 years ago
by
Matt Harvey
Browse files
Options
Download
Email Patches
Plain Diff
Simply train settings.
parent
0537dd89
master
add-demo-script
github/fork/Anner-deJong/patch-2
github/fork/ideaRunner/edit_move_file
github/fork/ikvision/master
github/fork/mertia-himanshu/master
github/fork/prabindh/master
github/fork/sanshibayuan/master
playground
playground-add-synthetic
v1.0
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
train.py
+15
-5
train.py
validate_rnn.py
+61
-0
validate_rnn.py
with
76 additions
and
5 deletions
+76
-5
train.py
+
15
-
5
View file @
79e29be4
...
...
@@ -67,13 +67,23 @@ def train(data_type, seq_length, model, saved_model=None,
def
main
():
"""These are the main training settings. Set each before running
this file."""
data_type
=
'features'
# can be 'features' or 'images'
seq_length
=
40
model
=
'lstm'
# see `models.py` for more
class_limit
=
None
# int, can be 1-101 or None
saved_model
=
None
# None or weights file
concat
=
False
# true for MLP only
image_shape
=
None
# Use None for default or specify a new size
class_limit
=
None
# int, can be 1-101 or None
# Chose images or features and image shape based on network.
if
model
==
'conv_3d'
or
model
==
'crnn'
:
data_type
=
'images'
image_shape
=
(
80
,
80
,
3
)
else
:
data_type
=
'features'
image_shape
=
None
# MLP requires flattened features.
if
model
==
'mlp'
:
concat
=
True
else
:
concat
=
False
train
(
data_type
,
seq_length
,
model
,
saved_model
=
saved_model
,
class_limit
=
class_limit
,
concat
=
concat
,
image_shape
=
image_shape
)
...
...
This diff is collapsed.
Click to expand it.
validate_rnn.py
0 → 100644
+
61
-
0
View file @
79e29be4
"""
Validate our RNN. Basically just runs a validation generator on
about the same number of videos as we have in our test set.
"""
from
keras.callbacks
import
TensorBoard
,
ModelCheckpoint
,
EarlyStopping
,
CSVLogger
from
models
import
ResearchModels
from
data
import
DataSet
from
PIL
import
Image
import
time
def
validate
(
data_type
,
model
,
seq_length
=
40
,
saved_model
=
None
,
concat
=
False
,
class_limit
=
None
,
image_shape
=
None
):
batch_size
=
32
# Get the data and process it.
if
image_shape
is
None
:
data
=
DataSet
(
seq_length
=
seq_length
,
class_limit
=
class_limit
)
else
:
data
=
DataSet
(
seq_length
=
seq_length
,
class_limit
=
class_limit
,
image_shape
=
image_shape
)
val_generator
=
data
.
frame_generator
(
batch_size
,
'test'
,
data_type
,
concat
)
# Get the model.
rm
=
ResearchModels
(
len
(
data
.
classes
),
model
,
seq_length
,
saved_model
)
# Evaluate!
results
=
rm
.
model
.
evaluate_generator
(
generator
=
val_generator
,
val_samples
=
3200
)
print
(
results
)
print
(
rm
.
model
.
metrics_names
)
def
main
():
model
=
'mlp'
saved_model
=
'data/ucf101/checkpoints/mlp-features.023-0.926.hdf5'
if
model
==
'conv_3d'
or
model
==
'crnn'
:
data_type
=
'images'
image_shape
=
(
80
,
80
,
3
)
else
:
data_type
=
'features'
image_shape
=
None
if
model
==
'mlp'
:
concat
=
True
else
:
concat
=
False
validate
(
data_type
,
model
,
saved_model
=
saved_model
,
concat
=
concat
,
image_shape
=
image_shape
)
if
__name__
==
'__main__'
:
main
()
This diff is collapsed.
Click to expand it.
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment
Menu
Explore
Projects
Groups
Snippets