"""
Train our RNN on extracted features or images.
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, CSVLogger
from models import ResearchModels
from data import DataSet
import time

def train(data_type, seq_length, model, saved_model=None,
          concat=False, class_limit=None, image_shape=None,
          load_to_memory=False, batch_size=32, nb_epoch=1000000,
          patience=100000, validation_steps=10):
    """Load the dataset, build the requested model and train it with Keras.

    Args:
        data_type: data kind forwarded to the DataSet loaders (main() uses
            'images' or 'features'); also part of the checkpoint filename.
        seq_length: sequence length forwarded to DataSet and ResearchModels.
        model: model name understood by ``ResearchModels`` (see models.py).
        saved_model: optional weights file forwarded to ``ResearchModels``.
        concat: forwarded to the DataSet loaders (main() sets it for 'mlp').
        class_limit: optional class cap forwarded to DataSet.
        image_shape: optional image shape forwarded to DataSet; not passed
            at all when None.
        load_to_memory: if True, load every train/test sequence up front and
            use ``model.fit``; otherwise stream batches with
            ``model.fit_generator``.
        batch_size: samples per batch.
        nb_epoch: maximum number of epochs. Together with the large
            ``patience`` default, this means training effectively runs until
            interrupted.
        patience: EarlyStopping patience, in epochs.
        validation_steps: validation batches per epoch in generator mode.
    """
    # Helper: Save the best model so far (ModelCheckpoint monitors val_loss
    # by default, which matches the val_loss in the filename pattern).
    checkpointer = ModelCheckpoint(
        filepath='./data/checkpoints/' + model + '-' + data_type +
        '.{epoch:03d}-{val_loss:.3f}.hdf5',
        verbose=1,
        save_best_only=True)

    # Helper: TensorBoard
    tb = TensorBoard(log_dir='./data/logs')

    # Helper: Stop when we stop learning.
    early_stopper = EarlyStopping(patience=patience)

    # Helper: Save results; the timestamp gives each run its own log file.
    timestamp = time.time()
    csv_logger = CSVLogger('./data/logs/' + model + '-' + 'training-' +
                           str(timestamp) + '.log')

    # The checkpointer must be registered too, otherwise no model is saved.
    callbacks = [tb, early_stopper, csv_logger, checkpointer]

    # Get the data. Only pass image_shape when given, so DataSet falls back
    # to whatever default it defines otherwise.
    dataset_kwargs = {'seq_length': seq_length, 'class_limit': class_limit}
    if image_shape is not None:
        dataset_kwargs['image_shape'] = image_shape
    data = DataSet(**dataset_kwargs)

    # Get steps per epoch. 0.7 is a guess at the share of data.data that is
    # the train set. Keras expects an integer step count, and an epoch needs
    # at least one step.
    steps_per_epoch = max(1, int(len(data.data) * 0.7 // batch_size))

    if load_to_memory:
        # Get data.
        X, y = data.get_all_sequences_in_memory('train', data_type, concat)
        X_test, y_test = data.get_all_sequences_in_memory(
            'test', data_type, concat)
    else:
        # Get generators.
        generator = data.frame_generator(batch_size, 'train', data_type, concat)
        val_generator = data.frame_generator(
            batch_size, 'test', data_type, concat)

    # Get the model.
    rm = ResearchModels(len(data.classes), model, seq_length, saved_model)

    # Fit!
    if load_to_memory:
        # Use standard fit.
        rm.model.fit(
            X,
            y,
            batch_size=batch_size,
            validation_data=(X_test, y_test),
            verbose=1,
            callbacks=callbacks,
            epochs=nb_epoch)
    else:
        # Use fit generator.
        rm.model.fit_generator(
            generator=generator,
            steps_per_epoch=steps_per_epoch,
            epochs=nb_epoch,
            verbose=1,
            callbacks=callbacks,
            validation_data=val_generator,
            validation_steps=validation_steps)

def main():
    """These are the main training settings. Set each before running
    this file."""
    model = 'conv_3d'  # see `models.py` for more
    saved_model = None  # None or weights file
    class_limit = 2  # int, can be 1-101 or None
    seq_length = 40
    load_to_memory = True  # pre-load the sequences into memory

    # Choose images or features, and the image shape, based on the network.
    if model == 'conv_3d':
        data_type = 'images'
        image_shape = (80, 80, 3)
    elif model == 'lrcn':
        # Same 'images' value as the conv_3d branch (was the mismatched
        # 'image'); presumably the DataSet loaders key on this exact string
        # (see data.py).
        data_type = 'images'
        image_shape = (150, 150, 3)
    else:
        data_type = 'features'
        image_shape = None

    # MLP requires flattened features.
    concat = model == 'mlp'

    train(data_type, seq_length, model, saved_model=saved_model,
          class_limit=class_limit, concat=concat, image_shape=image_shape,
          load_to_memory=load_to_memory)

# Script entry point: run training only when executed directly, not on import.
if "__main__" == __name__:
    main()