Commit 9df6de57 authored by Matt Harvey's avatar Matt Harvey
Browse files

Add a demo script to make predictions on a single sample.

Showing with 13418 additions and 13323 deletions
+13418 -13323
......@@ -206,6 +206,33 @@ class DataSet():
else:
return None
def get_frames_by_filename(self, filename, data_type):
    """Return the input sequence for the sample identified by `filename`.

    The sample is the first row of `self.data` whose third field
    (`row[2]`) equals `filename`.

    When `data_type` is "images", the sample's frames are fetched with
    `get_frames_for_sample`, passed through `rescale_list` together with
    `self.seq_length`, and turned into a sequence by
    `build_image_sequence`. For any other `data_type`, the sequence is
    obtained from `get_extracted_sequence(data_type, sample)`.

    Raises:
        ValueError: if no row matches `filename`, or if the extracted
            sequence comes back as None.
    """
    # Pick the first matching row; None signals "not found".
    sample = next(
        (candidate for candidate in self.data if candidate[2] == filename),
        None,
    )
    if sample is None:
        raise ValueError("Couldn't find sample: %s" % filename)

    if data_type != "images":
        # Anything other than raw images is read back as a stored sequence.
        sequence = self.get_extracted_sequence(data_type, sample)
        if sequence is None:
            raise ValueError("Can't find sequence. Did you generate them?")
        return sequence

    # Image path: gather the frames, resample them, then build the sequence.
    all_frames = self.get_frames_for_sample(sample)
    resampled = self.rescale_list(all_frames, self.seq_length)
    return self.build_image_sequence(resampled)
@staticmethod
def get_frames_for_sample(sample):
"""Given a sample row from the data file, get all the corresponding frame
......@@ -236,12 +263,11 @@ class DataSet():
# Cut off the last one if needed.
return output[:size]
@staticmethod
def print_class_from_prediction(predictions, nb_to_return=5):
def print_class_from_prediction(self, predictions, nb_to_return=5):
"""Given a prediction, print the top classes."""
# Get the prediction for each label.
label_predictions = {}
for i, label in enumerate(data.classes):
for i, label in enumerate(self.classes):
label_predictions[label] = predictions[i]
# Now sort them.
......
This diff is collapsed.
demo.py 0 → 100644
"""
Given a video path and a saved model (checkpoint), produce classification
predictions.
Note that if using a model that requires features to be extracted, those
features must be extracted first.
Note also that this is a rushed demo script to help a few people who have
requested it and so is quite "rough". :)
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, CSVLogger
from keras.models import load_model
from models import ResearchModels
from data import DataSet
import time
import os.path
import numpy as np
def predict(data_type, seq_length, saved_model, image_shape, video_name, class_limit):
    """Load a saved model, fetch one sample by name and print its prediction.

    Args:
        data_type: forwarded to `DataSet.get_frames_by_filename` to select
            how the sample's sequence is obtained.
        seq_length: forwarded to the `DataSet` constructor.
        saved_model: path handed to `load_model`.
        image_shape: forwarded to the `DataSet` constructor; when None the
            argument is omitted entirely.
        video_name: sample name looked up in the data set.
        class_limit: forwarded to the `DataSet` constructor.
    """
    model = load_model(saved_model)

    # Build the data set; only pass image_shape when one was supplied.
    dataset_kwargs = {'seq_length': seq_length, 'class_limit': class_limit}
    if image_shape is not None:
        dataset_kwargs['image_shape'] = image_shape
    dataset = DataSet(**dataset_kwargs)

    # Pull the single sample we want to classify.
    sample = dataset.get_frames_by_filename(video_name, data_type)

    # The model is given a batch of one: add a leading axis for predict().
    batch = np.expand_dims(sample, axis=0)
    prediction = model.predict(batch)
    print(prediction)

    # Drop the leading batch axis again before mapping scores to classes.
    # NOTE(review): print_class_from_prediction is documented as printing the
    # top classes itself; if it returns None this outer print also emits
    # "None" -- confirm and drop the wrapper if so.
    scores = np.squeeze(prediction, axis=0)
    print(dataset.print_class_from_prediction(scores))
def main():
    """Run the demo with hard-coded settings.

    Picks the data type and image shape from the chosen model name and
    hands everything to `predict`. Raises ValueError for an unrecognised
    model name.
    """
    # Model families, grouped by the kind of input they are fed here.
    image_models = ('conv_3d', 'c3d', 'lrcn')
    feature_models = ('lstm', 'mlp')

    # model can be one of lstm, lrcn, mlp, conv_3d, c3d.
    model = 'lstm'

    # Saved checkpoint that will be loaded for prediction.
    saved_model = 'data/checkpoints/lstm-features.026-0.239.hdf5'

    # Sequence length must match the length used during training.
    seq_length = 40

    # Limit must match that used during training.
    class_limit = 4

    # Demo file. Must already be extracted & features generated (if the
    # model requires them). Do not include the extension.
    # Assumes it's in data/[train|test]/ and is part of the train/test data.
    # TODO Make this way more useful. It should take in the path to
    # an actual video file, extract frames, generate sequences, etc.
    # Another sample that can be tried: 'v_Archery_g04_c02'
    video_name = 'v_ApplyLipstick_g01_c01'

    # Choose images or features, and the image shape, based on the network.
    if model in image_models:
        data_type, image_shape = 'images', (80, 80, 3)
    elif model in feature_models:
        data_type, image_shape = 'features', None
    else:
        raise ValueError("Invalid model. See train.py for options.")

    predict(data_type, seq_length, saved_model, image_shape, video_name, class_limit)
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment