From 952ab7c7f7990e650330f343be831c4e2277ffc3 Mon Sep 17 00:00:00 2001
From: Matt Harvey <harveym@gmail.com>
Date: Sun, 19 Nov 2017 21:43:34 -0800
Subject: [PATCH] Fix one-hot bug when encoding y

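Keras 2 moved the one-hot helper from keras.utils.np_utils to
keras.utils.to_categorical, and (assuming Keras 2.x semantics) it now
returns a 1-D one-hot row when given a scalar label index. The old
label_hot[0] workaround therefore stopped extracting "a single row"
and instead returned a single float, breaking the encoded y. Call the
new helper directly and drop the extra indexing. A rough sketch of the
behavior this relies on:

    from keras.utils import to_categorical
    to_categorical(3, 5)  # array([0., 0., 0., 1., 0.]) -- already one row

While here, have the LSTM return only its final output
(return_sequences=False) rather than the full sequence, which makes
the Flatten layer unnecessary, and raise the default class_limit from
2 to 10.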
---
 data.py   | 5 ++---
 models.py | 4 ++--
 train.py  | 2 +-
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/data.py b/data.py
index 9ae0b49..e6ba0c4 100644
--- a/data.py
+++ b/data.py
@@ -10,7 +10,7 @@ import sys
 import operator
 import threading
 from processor import process_image
-from keras.utils import np_utils
+from keras.utils import to_categorical
 
 class threadsafe_iterator:
     def __init__(self, iterator):
@@ -98,8 +98,7 @@ class DataSet():
         label_encoded = self.classes.index(class_str)
 
         # Now one-hot it.
-        label_hot = np_utils.to_categorical(label_encoded, len(self.classes))
-        label_hot = label_hot[0]  # just get a single row
+        label_hot = to_categorical(label_encoded, len(self.classes))
 
         return label_hot
 
diff --git a/models.py b/models.py
index a93d30b..9bfbe1f 100644
--- a/models.py
+++ b/models.py
@@ -78,9 +78,9 @@ class ResearchModels():
         our CNN to this model predominantly."""
         # Model.
         model = Sequential()
-        model.add(LSTM(2048, return_sequences=True, input_shape=self.input_shape,
+        model.add(LSTM(2048, return_sequences=False,
+                       input_shape=self.input_shape,
                        dropout=0.5))
-        model.add(Flatten())
         model.add(Dense(512, activation='relu'))
         model.add(Dropout(0.5))
         model.add(Dense(self.nb_classes, activation='softmax'))
diff --git a/train.py b/train.py
index 78af367..f931b3f 100644
--- a/train.py
+++ b/train.py
@@ -86,7 +86,7 @@ def main():
     # model can be one of lstm, lrcn, mlp, conv_3d, c3d
     model = 'lstm'
     saved_model = None  # None or weights file
-    class_limit = 2  # int, can be 1-101 or None
+    class_limit = 10  # int, can be 1-101 or None
     seq_length = 40
     load_to_memory = False  # pre-load the sequences into memory
     batch_size = 32
-- 
GitLab