Commit 22d0a85e authored by 宋柯's avatar 宋柯

模型调试

parent 3963e683
......@@ -1009,7 +1009,7 @@ if __name__ == '__main__':
elif field.startswith(ITEM_PREFIX + NUMERIC_PREFIX) or field.startswith(USER_PREFIX + NUMERIC_PREFIX):
fields_na_value_dict[field] = 0
samples = samples.na.fill(fields_na_value_dict).coalesce(1)
samples = samples.na.fill(fields_na_value_dict)
samples.printSchema()
......@@ -1026,26 +1026,25 @@ if __name__ == '__main__':
output_file = "file:///home/gmuser/" + categoty_field + "_vocab"
output_file = "/strategy/" + categoty_field + "_vocab"
# train_samples.select(categoty_field).where(F.col(categoty_field) != '-1').where(F.col(categoty_field) != '').distinct().write.mode("overwrite").options(header="false").csv(output_file)
categoty_field_rows = train_samples.select(categoty_field).where(F.col(categoty_field) != '-1').where(F.col(categoty_field) != '').distinct().collect()
categoty_field_rows = train_samples.select(categoty_field).where(F.col(categoty_field) != '-1').where(F.col(categoty_field) != '').distinct().coalesce(1).collect()
vocab_redis_keys.append("strategy:" + categoty_field + ":vocab")
saveVocab(vocab_redis_keys[-1], list(map(lambda row: row[categoty_field], categoty_field_rows)))
for multi_categoty_field in multi_categoty_fields:
output_file = "file:///home/gmuser/" + multi_categoty_field + "_vocab"
output_file = "/strategy/" + multi_categoty_field + "_vocab"
# train_samples.selectExpr("explode(split({multi_categoty_field},','))".format(multi_categoty_field = multi_categoty_field)).where(F.col(multi_categoty_field) != '-1').distinct().write.mode("overwrite").options(header="false").csv(output_file)
multi_categoty_field_rows = train_samples.selectExpr("explode(split({multi_categoty_field},',')) as {multi_categoty_field}".format(multi_categoty_field = multi_categoty_field)).where(F.col(multi_categoty_field) != '-1').where(F.col(multi_categoty_field) != '').distinct().collect()
multi_categoty_field_rows = train_samples.selectExpr("explode(split({multi_categoty_field},',')) as {multi_categoty_field}".format(multi_categoty_field = multi_categoty_field)).where(F.col(multi_categoty_field) != '-1').where(F.col(multi_categoty_field) != '').distinct().coalesce(1).collect()
vocab_redis_keys.append("strategy:" + multi_categoty_field + ":vocab")
saveVocab(vocab_redis_keys[-1], list(map(lambda row: row[multi_categoty_field], multi_categoty_field_rows)))
saveVocab("strategy:all:vocab", vocab_redis_keys)
output_file = "file:///home/gmuser/train_samples"
output_file = "/strategy/train_samples"
train_samples.write.mode("overwrite").options(header="false", sep='|').csv(output_file)
import tensorflow as tf
def get_example_string(line):
splits = line.split('|')
def get_example_string(splits):
def fill_null(s):
    """Return '-1' as a sentinel when *s* is None, otherwise return *s* unchanged.

    Used to guard fields that may be NULL in the sample row before calling
    ``.split(',')`` / ``.encode()`` on them when building tf.train.Feature values.
    Parameter renamed from ``str`` to avoid shadowing the builtin; all visible
    call sites pass the argument positionally.
    """
    return '-1' if s is None else s
features = {
'ITEM_CATEGORY_card_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[0].encode()])),
'USER_CATEGORY_device_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[2].encode()])),
......@@ -1053,13 +1052,13 @@ if __name__ == '__main__':
'USER_CATEGORY_user_city_id': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[4].encode()])),
'USER_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[6].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[6]).split(','))))),
'USER_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[7].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[7]).split(','))))),
'USER_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[8].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[8]).split(','))))),
'USER_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[9].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[9]).split(','))))),
'ITEM_NUMERIC_click_count_sum': tf.train.Feature(
float_list=tf.train.FloatList(value=[float(splits[10])])),
'ITEM_NUMERIC_click_count_avg': tf.train.Feature(
......@@ -1093,13 +1092,13 @@ if __name__ == '__main__':
'ITEM_CATEGORY_hospital_is_high_quality': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[27].encode()])),
'ITEM_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[28].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[28]).split(','))))),
'ITEM_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[29].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[29]).split(','))))),
'ITEM_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[30].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[30]).split(','))))),
'ITEM_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[31].split(','))))),
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), fill_null(splits[31]).split(','))))),
'ITEM_NUMERIC_sku_price': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[32])])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(splits[5])])),
}
......@@ -1114,14 +1113,16 @@ if __name__ == '__main__':
return tf_serialized
output_file = "file:///home/gmuser/train_samples"
output_file = "/strategy/train_samples_tfrecord"
os.system("hdfs dfs -rmr {output_file}".format(output_file = output_file))
train_samples.rdd.map(get_example_string).coalesce(1).saveAsTextFile(output_file)
output_file = "file:///home/gmuser/eval_samples"
output_file = "/strategy/eval_samples"
test_samples.write.mode("overwrite").options(header="false", sep='|').csv(output_file)
output_file = "/strategy/eval_samples_tfrecord"
os.system("hdfs dfs -rmr {output_file}".format(output_file=output_file))
test_samples.rdd.map(get_example_string).coalesce(1).saveAsTextFile(output_file)
print("训练数据写入 耗时s:{}".format(time.time() - write_time_start))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment