TensorFlow has two basic mechanisms for importing data: one feeds data in through placeholders, the other makes the data reading itself an op in the computation graph.
When the graph is executed on a GPU, the first approach can stall because the data has to be copied from the CPU to the GPU on every step, whereas the second can overlap reading with computation using multiple threads and thus improve throughput.
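To make the contrast concrete, here is a minimal TF 1.x sketch (not from the original text; the toy array `features_np` and the `reduce_mean` "model" are illustrative assumptions). The first variant copies data into the graph through a placeholder on every `sess.run`; the second builds the input pipeline as `tf.data` ops, so reading and prefetching happen inside the runtime and can overlap with the compute.

```python
import numpy as np
import tensorflow as tf

# Toy data; the name and shape are illustrative assumptions.
features_np = np.random.rand(1000, 10).astype(np.float32)

# Variant 1: placeholder + feed_dict -- every sess.run copies the batch
# from host memory into the graph (and on to the GPU, if one is used).
x_ph = tf.placeholder(tf.float32, shape=[None, 10])
mean_ph = tf.reduce_mean(x_ph)

# Variant 2: the input pipeline itself is a set of ops in the graph
# (tf.data), so reading and prefetching can overlap with computation.
dataset = tf.data.Dataset.from_tensor_slices(features_np).batch(100).prefetch(2)
x_op = dataset.make_one_shot_iterator().get_next()
mean_op = tf.reduce_mean(x_op)

with tf.Session() as sess:
    print(sess.run(mean_ph, feed_dict={x_ph: features_np[:100]}))  # feed_dict path
    print(sess.run(mean_op))                                       # pipeline path
```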
TFRecordDataset
TFRecord is the standard data format in TensorFlow, and TFRecordDataset is the tf.data class used to read it.
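Before looking at how the files are generated, it may help to see what a single record contains. The following is a minimal, illustrative Python sketch (the file name and feature values are assumptions; the feature names label / idx / value mirror the parsing code further down): each record is one serialized tf.train.Example protobuf.

```python
import tensorflow as tf

# A TFRecord file is a sequence of serialized tf.train.Example protobufs,
# each holding named features. File name and values are illustrative only.
with tf.python_io.TFRecordWriter("train.tfrecord") as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
        'idx':   tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 17, 42])),
        'value': tf.train.Feature(float_list=tf.train.FloatList(value=[0.5, 1.2, 0.7])),
    }))
    writer.write(example.SerializeToString())
```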
Generating TFRecord files
With Spark you can use spark-tensorflow-connector. Download the pre-built jar, load it via --jars path-to-spark-tensorflow-connector.jar, and you can then write a DataFrame directly out as distributed TFRecord files.
Note that the connector requires Spark 2.0 or later.
```scala
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

val path = "test-output.tfrecord"
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
val schema = StructType(List(
  StructField("id", IntegerType),
  StructField("IntegerCol", IntegerType),
  StructField("LongCol", LongType),
  StructField("FloatCol", FloatType),
  StructField("DoubleCol", DoubleType),
  StructField("VectorCol", ArrayType(DoubleType, true)),
  StructField("StringCol", StringType)))

val rdd = spark.sparkContext.parallelize(testRows)

// Save DataFrame as TFRecords
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").option("recordType", "Example").save(path)

// Read TFRecords into DataFrame.
// The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
importedDf1.show()

// Read TFRecords into DataFrame using a custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()
```
Submit command
```bash
## jar on the local filesystem
$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector_2.11-1.10.0.jar

## jar on HDFS
$SPARK_HOME/bin/spark-shell --jars viewfs://xxxx/path/to/spark/tensorflow/connector.jar
```
Reading with Dataset
```python
import tensorflow as tf

# Parsing function; the input is a serialized tf.train.Example
def parse_tfrecords_function(example_proto):
    features = {
        'label': tf.FixedLenFeature([1], tf.float32),
        'idx': tf.VarLenFeature(tf.int64),
        'value': tf.VarLenFeature(tf.float32)
    }
    # tf.parse_single_example decodes the serialized data into tensors
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['idx'], parsed_features['value'], parsed_features['label']

# filenames is a list of TFRecord file paths
dataset = tf.data.TFRecordDataset(filenames)
# map, repeat, prefetch, shuffle and batch are all optional; you can call
# make_initializable_iterator on the dataset directly
it = (dataset.map(parse_tfrecords_function)
             .repeat(10)
             .prefetch(10240)
             .shuffle(10240)
             .batch(128)
             .make_initializable_iterator())
idx_op, value_op, label_op = it.get_next()
# The iterator must be initialized before use
sess.run(it.initializer)
```
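The snippet above stops after initializing the iterator; a typical consumption loop looks like the following minimal sketch (the filenames list and the loop body are illustrative assumptions, not from the original):

```python
import tensorflow as tf

filenames = ["train.tfrecord"]   # illustrative; any list of TFRecord paths works
# ... build dataset, it, idx_op, value_op, label_op exactly as above ...

with tf.Session() as sess:
    sess.run(it.initializer)
    while True:
        try:
            idx, value, label = sess.run([idx_op, value_op, label_op])
            # idx and value come back as tf.SparseTensorValue because they were
            # parsed with tf.VarLenFeature; label is a dense (batch_size, 1) array.
        except tf.errors.OutOfRangeError:
            # Raised once the 10 repeats of the dataset are exhausted.
            break
```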