validate_rnn.py 1.45 KB
Newer Older
Matt Harvey's avatar
Matt Harvey committed
1
2
3
4
"""
Validate our RNN. Basically just runs a validation generator on
about the same number of videos as we have in our test set.
"""
Matt Harvey's avatar
Matt Harvey committed
5
from keras.callbacks import TensorBoard, ModelCheckpoint, CSVLogger
Matt Harvey's avatar
Matt Harvey committed
6
7
8
9
from models import ResearchModels
from data import DataSet

def validate(data_type, model, seq_length=40, saved_model=None,
             class_limit=None, image_shape=None, batch_size=32,
             nb_samples=3200):
    """Evaluate a saved model on the 'test' split and print the results.

    Builds a ``DataSet``, draws batches from its 'test' frame generator and
    runs ``evaluate_generator`` over roughly ``nb_samples`` sequences. It
    then prints the returned values followed by ``rm.model.metrics_names``.

    Args:
        data_type: Forwarded to ``DataSet.frame_generator`` (``main`` uses
            'images' or 'features').
        model: Model name handed to ``ResearchModels`` (e.g. 'lstm').
        seq_length: Sequence length forwarded to ``DataSet`` and
            ``ResearchModels``.
        saved_model: Checkpoint path handed to ``ResearchModels``.
        class_limit: Optional class cap forwarded to ``DataSet``.
        image_shape: Optional shape forwarded to ``DataSet``. When None it
            is not passed at all, so ``DataSet`` keeps its own default.
        batch_size: Number of sequences per generator batch.
        nb_samples: Approximate number of sequences to evaluate. It is
            converted to a whole number of generator steps (batches),
            rounded up.

    Returns:
        Whatever ``evaluate_generator`` returns (it is also printed).

    Raises:
        ValueError: If ``batch_size`` or ``nb_samples`` is less than 1.
    """
    if batch_size < 1 or nb_samples < 1:
        raise ValueError("batch_size and nb_samples must both be >= 1")

    # Get the data. image_shape is only passed when provided so that
    # DataSet's own default applies otherwise (same as the original branches).
    dataset_kwargs = {'seq_length': seq_length, 'class_limit': class_limit}
    if image_shape is not None:
        dataset_kwargs['image_shape'] = image_shape
    data = DataSet(**dataset_kwargs)

    val_generator = data.frame_generator(batch_size, 'test', data_type)

    # Get the model.
    rm = ResearchModels(len(data.classes), model, seq_length, saved_model)

    # evaluate_generator counts *steps* (batches), not samples. The old
    # Keras 1 keyword ``val_samples=3200`` is renamed to ``steps=3200`` as-is
    # by Keras 2's legacy shim, i.e. 3200 batches instead of ~3200 sequences.
    # Convert samples -> steps explicitly (ceiling division).
    steps = -(-nb_samples // batch_size)

    # Evaluate!
    results = rm.model.evaluate_generator(
        generator=val_generator,
        steps=steps)

    print(results)
    print(rm.model.metrics_names)
    return results

def main():
    """Run ``validate`` for a hard-coded model name and checkpoint.

    'conv_3d' and 'lrcn' are evaluated with data_type 'images' and an
    image_shape of (80, 80, 3). Any other model name uses data_type
    'features' and no image_shape. The class count is limited to 4.
    """
    model = 'lstm'
    saved_model = 'data/checkpoints/lstm-features.026-0.239.hdf5'

    # Pick the data flavour from the model name.
    wants_images = model in ('conv_3d', 'lrcn')
    data_type = 'images' if wants_images else 'features'
    image_shape = (80, 80, 3) if wants_images else None

    validate(data_type, model, saved_model=saved_model,
             image_shape=image_shape, class_limit=4)


if __name__ == '__main__':
    # Only evaluate when executed as a script, not when imported.
    main()