I am trying to create my own training data for the TextSum model. As I understand it, I need to put my articles and abstracts into a binary file (in TFRecords format). However, I cannot create my own training data from raw text files. I don't understand the format very clearly, so I am trying to create a very simple binary file using the following code:
```python
import os
import tensorflow as tf

# `path` is the directory that holds the raw text files
files = os.listdir(path)
writer = tf.python_io.TFRecordWriter("test_data")
for file in files:
    # read as bytes so the content can go into a BytesList
    with open(os.path.join(path, file), "rb") as f:
        content = f.read()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'content': tf.train.Feature(bytes_list=tf.train.BytesList(value=[content]))
            }
        )
    )
    serialized = example.SerializeToString()
    writer.write(serialized)
writer.close()  # flush buffered records to disk
```

And I tried to use the following code to read the value back from this test_data file:
```python
reader = open("test_data", 'rb')
len_bytes = reader.read(8)
str_len = struct.unpack('q', len_bytes)[0]
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
example_pb2.Example.FromString(example_str)
```

But I always get the following error:
```
  File "dailymail_corpus_to_tfrecords.py", line 34, in check_file
    example_pb2.Example.FromString(example_str)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 770, in FromString
    message.MergeFromString(s)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1091, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1117, in InternalParse
    new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/decoder.py", line 791, in _SkipLengthDelimited
    raise _DecodeError('Truncated message.')
google.protobuf.message.DecodeError: Truncated message.
```

I have no idea what is wrong. Please let me know if you have any suggestions to solve this issue.
Best answer
For anyone who has the same issue: I had to look at the TensorFlow source code to see how TFRecordWriter writes out the data. I realized that it actually writes 8 bytes for the length followed by 4 bytes for a CRC checksum, which means the first 12 bytes of each record are the header. The sample binary file in the TextSum code seems to have only an 8-byte header, which is why TextSum uses reader.read(8) to get the length of the data and then reads the rest as features. Note that each TFRecord record is also followed by a 4-byte CRC of the data, so those 4 bytes must be skipped as well before reading the next record.
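Another option, assuming you want files that a reader using the 8-byte-header convention (`reader.read(8)` then `struct.unpack('q', ...)`) can consume unchanged, is to skip TFRecordWriter and write that format yourself: an 8-byte length followed directly by the serialized example, with no CRC fields. This is a minimal stdlib-only sketch; `write_textsum_binary` and `read_textsum_binary` are hypothetical helper names, and the native `'q'` format simply mirrors the `struct.unpack('q', ...)` call used in the question rather than anything confirmed from TextSum's own writer.

```python
import struct


def write_textsum_binary(path, serialized_examples):
    """Write each record as <8-byte length><bytes>, with no CRC fields.

    Hypothetical helper (not a TensorFlow API); the layout mirrors the
    reader.read(8) / struct.unpack('q', ...) reading pattern.
    """
    with open(path, "wb") as f:
        for data in serialized_examples:
            f.write(struct.pack("q", len(data)))  # 8-byte length prefix
            f.write(data)                          # serialized example bytes


def read_textsum_binary(path):
    """Yield each record written by write_textsum_binary, in order."""
    with open(path, "rb") as f:
        while True:
            len_bytes = f.read(8)
            if not len_bytes:
                break  # clean end of file
            if len(len_bytes) < 8:
                raise ValueError("truncated length prefix")
            (str_len,) = struct.unpack("q", len_bytes)
            data = f.read(str_len)
            if len(data) < str_len:
                raise ValueError("truncated record")
            yield data
```

In practice you would pass `example.SerializeToString()` for each `tf.train.Example` you build, e.g. `write_textsum_binary("test_data", [ex.SerializeToString() for ex in examples])`.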
My working solution is:
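```python
import struct
from tensorflow.core.example import example_pb2

reader = open("test_data", 'rb')
len_bytes = reader.read(8)   # 8-byte record length
reader.read(4)               # ignore the next 4 bytes (CRC of the length)
str_len = struct.unpack('q', len_bytes)[0]
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
example = example_pb2.Example.FromString(example_str)
# This reads only the first record; a 4-byte CRC of the data follows it.
reader.close()
```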
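The working solution above reads only the first record. To read every record, here is a sketch assuming the standard TFRecord layout (8-byte little-endian length, 4-byte masked CRC32C of the length, the data, then a 4-byte masked CRC32C of the data). `iter_tfrecords` is a hypothetical helper, and it skips the checksums instead of verifying them, because CRC32C is not in the Python standard library. In TensorFlow 1.x, `tf.python_io.tf_record_iterator` does the same job without hand-written parsing.

```python
import struct


def iter_tfrecords(path):
    """Yield the raw payload of every record in a TFRecord file.

    Per-record layout:
        uint64 length                       (8 bytes, little-endian)
        uint32 masked CRC32C of the length  (4 bytes)
        byte   data[length]
        uint32 masked CRC32C of the data    (4 bytes)
    Both CRCs are skipped here, not verified.
    """
    with open(path, "rb") as f:
        while True:
            header = f.read(12)  # length + length CRC
            if not header:
                break  # clean end of file
            if len(header) < 12:
                raise ValueError("truncated record header")
            (length,) = struct.unpack("<Q", header[:8])
            data = f.read(length)
            footer = f.read(4)  # data CRC, skipped
            if len(data) < length or len(footer) < 4:
                raise ValueError("truncated record body")
            yield data
```

Each yielded string can be passed to `example_pb2.Example.FromString` to get the `Example` back.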