diff --git a/diaryCandidateSet.py b/diaryCandidateSet.py
index 9789684c06020693f94dffa4acdcf74fae02b73c..24a21f665eec0b17122d864d3f10daaeac443928 100644
--- a/diaryCandidateSet.py
+++ b/diaryCandidateSet.py
@@ -7,12 +7,8 @@ from config import *
 # 候选集cid只能从训练数据集cid中选择
 def filter_cid(df):
     data_set_cid = pd.read_csv(DIRECTORY_PATH + "data_set_cid.csv")["cid"].values.tolist()
-    print("过滤前样本大小:")
-    print(df.shape)
     if not df.empty:
         df = df.loc[df["cid"].isin(data_set_cid)]
-        print("过滤后样本大小:")
-        print(df.shape)
     return df
 
 
diff --git a/predictDiary.py b/predictDiary.py
index 371037a8b3430b02a8765f6eca9eda25456b6d71..5fb6ed2f69517cf125c77f98cbc76fc999641086 100644
--- a/predictDiary.py
+++ b/predictDiary.py
@@ -94,17 +94,21 @@ def router(device_id):
 
 if __name__ == "__main__":
     # TODO 如果耗时小于一分钟,下一次取到的device_id和上一次相同
-
     while True:
         start = time.time()
         empty,device_id_list = get_active_users()
         if empty:
             time.sleep(10)
         else:
+            # Reload the known-device file on every cycle and hold the ids in a set,
+            # so each membership test in the loop below is O(1) instead of a list scan.
+            old_device_ids = set(pd.read_csv(DIRECTORY_PATH + "data_set_device_id.csv")["device_id"].values.tolist())
             for device_id in device_id_list:
-                router(device_id)
-
+                if device_id in old_device_ids:
+                    router(device_id)
+                else:
+                    print("该用户不是老用户,不能预测")
             end = time.time()
             time_cost = (end - start)
-            print("预测耗时{}秒".format(time_cost))
+            print("耗时{}秒".format(time_cost))
 
diff --git a/processData.py b/processData.py
index 01447d4649659f1e6b48d3d12e538b3dfbd4fa46..a442657dde7cfce25193549b1920b4e01059094b 100644
--- a/processData.py
+++ b/processData.py
@@ -60,6 +60,13 @@ def feature_en():
     print(cid_df.head(2))
     cid_df.to_csv(DIRECTORY_PATH + "data_set_cid.csv", index=False)
 
+    # Persist the distinct device_ids seen in the data set. The prediction
+    # loop checks incoming device_ids against this file and skips any that
+    # are absent from it.
+    device_id_df = pd.DataFrame({"device_id": data["device_id"].unique()})
+    print("data_set_device_id :")
+    print(device_id_df.head(2))
+    device_id_df.to_csv(DIRECTORY_PATH + "data_set_device_id.csv", index=False)
     return data, test_number, validation_number
 
 
@@ -99,8 +106,5 @@ def ffm_transform(data, test_number, validation_number):
     train.to_csv(DIRECTORY_PATH + "train{0}-{1}.csv".format(DATA_START_DATE, VALIDATION_DATE), index=False, header=None)
 
 
-if __name__ == "__main__":
-    data_fe = feature_en()
-    ffm_transform(data_fe)
 
 
diff --git a/train.py b/train.py
index 8163b5ec5b06f88ba61d460a1a01cdd46c6912a0..8ec2c5b8efb285c22b9af142b91238a5b3cbbc52 100644
--- a/train.py
+++ b/train.py
@@ -5,9 +5,10 @@ from diaryCandidateSet import get_eachCityDiaryTop3000
 
 # 把数据获取、特征转换、模型训练的模型串联在一起
 if __name__ == "__main__":
-    data_fe = feature_en()
-    ffm_transform(data_fe)
+    data, test_number, validation_number = feature_en()
+    ffm_transform(data, test_number, validation_number)
     train()
-    print('---------------prepare candidates--------------')
-    get_eachCityDiaryTop3000()
+    print("end")
+    # print('---------------prepare candidates--------------')
+    # get_eachCityDiaryTop3000()