Commit 79e29be4 authored by Matt Harvey's avatar Matt Harvey
Browse files

Simply train settings.

Showing with 76 additions and 5 deletions
+76 -5
......@@ -67,13 +67,23 @@ def train(data_type, seq_length, model, saved_model=None,
def main():
"""These are the main training settings. Set each before running
this file."""
data_type = 'features' # can be 'features' or 'images'
seq_length = 40
model = 'lstm' # see `models.py` for more
class_limit = None # int, can be 1-101 or None
saved_model = None # None or weights file
concat = False # true for MLP only
image_shape = None # Use None for default or specify a new size
class_limit = None # int, can be 1-101 or None
# Chose images or features and image shape based on network.
if model == 'conv_3d' or model == 'crnn':
data_type = 'images'
image_shape = (80, 80, 3)
else:
data_type = 'features'
image_shape = None
# MLP requires flattened features.
if model == 'mlp':
concat = True
else:
concat = False
train(data_type, seq_length, model, saved_model=saved_model,
class_limit=class_limit, concat=concat, image_shape=image_shape)
......
"""
Validate our RNN. Basically just runs a validation generator on
about the same number of videos as we have in our test set.
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, CSVLogger
from models import ResearchModels
from data import DataSet
from PIL import Image
import time
def validate(data_type, model, seq_length=40, saved_model=None,
concat=False, class_limit=None, image_shape=None):
batch_size = 32
# Get the data and process it.
if image_shape is None:
data = DataSet(
seq_length=seq_length,
class_limit=class_limit
)
else:
data = DataSet(
seq_length=seq_length,
class_limit=class_limit,
image_shape=image_shape
)
val_generator = data.frame_generator(batch_size, 'test', data_type, concat)
# Get the model.
rm = ResearchModels(len(data.classes), model, seq_length, saved_model)
# Evaluate!
results = rm.model.evaluate_generator(
generator=val_generator,
val_samples=3200)
print(results)
print(rm.model.metrics_names)
def main():
model = 'mlp'
saved_model = 'data/ucf101/checkpoints/mlp-features.023-0.926.hdf5'
if model == 'conv_3d' or model == 'crnn':
data_type = 'images'
image_shape = (80, 80, 3)
else:
data_type = 'features'
image_shape = None
if model == 'mlp':
concat = True
else:
concat = False
validate(data_type, model, saved_model=saved_model,
concat=concat, image_shape=image_shape)
if __name__ == '__main__':
main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment