From a0be2abc81a48646a23280f2947ef1f9ce99fecb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=AD=E7=BE=BD?= <guoyu@igengmei.com>
Date: Wed, 26 May 2021 19:26:47 +0800
Subject: [PATCH] =?UTF-8?q?=E7=BE=8E=E8=B4=AD=E7=B2=BE=E6=8E=92=E6=A8=A1?=
 =?UTF-8?q?=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mlp/train.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlp/train.py b/mlp/train.py
index d73ae4e..d8cd1a8 100644
--- a/mlp/train.py
+++ b/mlp/train.py
@@ -39,18 +39,18 @@ def getDataVocabFromRedis(version):
     return dataVocab
 
 # 数据类型转换
-def csvTypeConvert(df,data_vocab):
-    # 离散na值填充
-    for k, v in data_vocab.items():
-        df[k] = df[k].fillna("-1")
-        df[k] = df[k].astype("string")
-
-    for k in ITEM_NUMBER_COLUMNS:
-        df[k] = df[k].fillna(0.0)
-        df[k] = df[k].astype("float")
+def csvTypeConvert(columns,df,data_vocab):
+    for k in columns:
+        # Discrete columns (keys of data_vocab) become strings with NA -> "-1"; all others become floats with NA -> 0.0.
+        if k in data_vocab:
+            df[k] = df[k].astype("string")
+            df[k] = df[k].fillna("-1")
+        else:
+            df[k] = df[k].astype("float")
+            df[k] = df[k].fillna(0.0)
 
     df["label"] = df["label"].astype("int")
-    print(df.dtypes)
+    # print(df.dtypes)
     return df
 
 def loadData(data_path):
@@ -166,8 +166,8 @@ if __name__ == '__main__':
     print("trainSize:{},testSize{}".format(trainSize,testSize))
 
     # 数据类型转换
-    df_train = csvTypeConvert(df_train,data_vocab)
-    df_test = csvTypeConvert(df_test,data_vocab)
+    df_train = csvTypeConvert(datasColumns,df_train,data_vocab)
+    df_test = csvTypeConvert(datasColumns,df_test,data_vocab)
 
     # 获取训练数据
     train_data = getDataSet(df_train,shuffleSize=trainSize,)
-- 
2.18.0