"""
Train our RNN on extracted features or images.
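
Usage (a rough sketch, assuming frames and/or features have already been
extracted with the repo's data preparation scripts):

    python train.py

Set the training options in main() before running.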
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, CSVLogger
from models import ResearchModels
from data import DataSet
import time
import os.path

def train(data_type, seq_length, model, saved_model=None,
          concat=False, class_limit=None, image_shape=None,
          load_to_memory=False):
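    """Train the given model on either extracted features or raw frame images."""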
    # Set variables.
    nb_epoch = 1000000
    batch_size = 32

    # Helper: Save the model.
    checkpointer = ModelCheckpoint(
        filepath=os.path.join('data', 'checkpoints', model + '-' + data_type + \
            '.{epoch:03d}-{val_loss:.3f}.hdf5'),
        verbose=1,
        save_best_only=True)
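    # Note: the 'data/checkpoints' directory is assumed to exist already (nothing
    # here creates it); ModelCheckpoint fills in {epoch} and {val_loss} at save time.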

    # Helper: TensorBoard
    tb = TensorBoard(log_dir=os.path.join('data', 'logs'))

    # Helper: Stop when we stop learning.
    early_stopper = EarlyStopping(patience=100000)
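    # With patience set this high, early stopping effectively never triggers;
    # training continues until interrupted or nb_epoch is exhausted.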

    # Helper: Save results.
    timestamp = time.time()
    csv_logger = CSVLogger(os.path.join('data', 'logs', model + '-' + 'training-' + \
        str(timestamp) + '.log'))

    # Get the data and process it.
    if image_shape is None:
        data = DataSet(
            seq_length=seq_length,
            class_limit=class_limit
        )
    else:
        data = DataSet(
            seq_length=seq_length,
            class_limit=class_limit,
            image_shape=image_shape
        )
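    # image_shape is only passed when training directly on frame images;
    # feature-based models construct DataSet without it.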

    # Get samples per epoch.
    # Multiply by 0.7 to estimate how much of data.data is in the train split.
    steps_per_epoch = (len(data.data) * 0.7) // batch_size

    if load_to_memory:
        # Get data.
        X, y = data.get_all_sequences_in_memory('train', data_type, concat)
        X_test, y_test = data.get_all_sequences_in_memory('test', data_type, concat)
    else:
        # Get generators.
        generator = data.frame_generator(batch_size, 'train', data_type, concat)
        val_generator = data.frame_generator(batch_size, 'test', data_type, concat)
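        # frame_generator is expected to yield batches indefinitely; the
        # steps_per_epoch and validation_steps arguments below bound each epoch.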

    # Get the model.
    rm = ResearchModels(len(data.classes), model, seq_length, saved_model)
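    # ResearchModels builds and compiles the network named by `model` (see
    # models.py); `saved_model`, if given, should be a previously saved checkpoint.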

    # Fit!
    if load_to_memory:
        # Use standard fit.
        rm.model.fit(
            X,
            y,
            batch_size=batch_size,
            validation_data=(X_test, y_test),
            verbose=1,
            callbacks=[tb, early_stopper, csv_logger],
            epochs=nb_epoch)
    else:
        # Use fit generator.
        rm.model.fit_generator(
            generator=generator,
            steps_per_epoch=steps_per_epoch,
            epochs=nb_epoch,
            verbose=1,
            callbacks=[tb, early_stopper, csv_logger, checkpointer],
            validation_data=val_generator,
            validation_steps=10)

def main():
    """These are the main training settings. Set each before running
    this file."""
    model = 'lrcn'  # see `models.py` for more
    saved_model = None  # None or weights file
    class_limit = 2  # int, can be 1-101 or None
    seq_length = 40
    load_to_memory = True  # pre-load the sequences into memory

    # Choose images or features and image shape based on network.
    if model == 'conv_3d':
        data_type = 'images'
        image_shape = (80, 80, 3)
    elif model == 'lrcn':
        data_type = 'images'
        image_shape = (150, 150, 3)
    else:
        data_type = 'features'
        image_shape = None

    # MLP requires flattened features.
    if model == 'mlp':
        concat = True
    else:
        concat = False
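    # concat flattens each sequence of per-frame features into a single vector,
    # which is the input format the MLP expects (handled in data.py).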

    train(data_type, seq_length, model, saved_model=saved_model,
          class_limit=class_limit, concat=concat, image_shape=image_shape,
          load_to_memory=load_to_memory)

if __name__ == '__main__':
    main()