From c336d215edc718471424cb13c60bedcc8db0180e Mon Sep 17 00:00:00 2001 From: Allen Short Date: Sun, 11 Mar 2012 15:44:42 -0500 Subject: [PATCH] Add named elements for meta fields as an alternative to functions, in order to allow development of compiled parsers. --- construct/__init__.py | 6 +- construct/core.py | 118 +++++++++++++++++++++---- construct/macros.py | 13 +-- construct/protocols/application/dns.py | 14 +-- construct/tests/test_core.py | 6 +- construct/tests/unit.py | 13 ++- 6 files changed, 132 insertions(+), 38 deletions(-) diff --git a/construct/__init__.py b/construct/__init__.py index feac2960b..5f772cd28 100644 --- a/construct/__init__.py +++ b/construct/__init__.py @@ -84,10 +84,10 @@ def wrapper(*args, **kwargs): 'AdaptationError', 'Adapter', 'Alias', 'Aligned', 'AlignedStruct', 'Anchor', 'Array', 'ArrayError', 'BFloat32', 'BFloat64', 'Bit', 'BitField', 'BitIntegerAdapter', 'BitIntegerError', 'BitStruct', 'Bits', 'Bitwise', - 'Buffered', 'Byte', 'Bytes', 'CString', 'CStringAdapter', 'Const', + 'Buffered', 'Byte', 'Bytes', 'CString', 'CStringAdapter', 'Const', 'Constant', 'ConstAdapter', 'ConstError', 'Construct', 'ConstructError', 'Container', 'Debugger', 'Embed', 'Embedded', 'EmbeddedBitStruct', 'Enum', 'ExprAdapter', - 'Field', 'FieldError', 'Flag', 'FlagsAdapter', 'FlagsContainer', + 'Equals', 'Field', 'FieldError', 'Flag', 'FlagsAdapter', 'FlagsContainer', 'FlagsEnum', 'FormatField', 'GreedyRange', 'GreedyRepeater', 'HexDumpAdapter', 'If', 'IfThenElse', 'IndexingAdapter', 'LFloat32', 'LFloat64', 'LazyBound', 'LengthValueAdapter', 'ListContainer', @@ -106,5 +106,5 @@ def wrapper(*args, **kwargs): 'SymmetricMapping', 'Terminator', 'TerminatorError', 'Tunnel', 'TunnelAdapter', 'UBInt16', 'UBInt32', 'UBInt64', 'UBInt8', 'ULInt16', 'ULInt32', 'ULInt64', 'ULInt8', 'UNInt16', 'UNInt32', 'UNInt64', 'UNInt8', - 'Union', 'ValidationError', 'Validator', 'Value', "Magic", + 'Union', 'ValidationError', 'Validator', 'Value', 'ValueOf', "Magic", ] diff 
--git a/construct/core.py b/construct/core.py index cd8dc5beb..e3985337d 100644 --- a/construct/core.py +++ b/construct/core.py @@ -387,12 +387,19 @@ def __init__(self, name, lengthfunc): Construct.__init__(self, name) self.lengthfunc = lengthfunc self._set_flag(self.FLAG_DYNAMIC) + + def length(self, context): + try: + return self.lengthfunc.value(context) + except AttributeError: + return self.lengthfunc(context) + def _parse(self, stream, context): - return _read_stream(stream, self.lengthfunc(context)) + return _read_stream(stream, self.length(context)) def _build(self, obj, stream, context): - _write_stream(stream, self.lengthfunc(context), obj) + _write_stream(stream, self.length(context), obj) def _sizeof(self, context): - return self.lengthfunc(context) + return self.length(context) #=============================================================================== @@ -418,10 +425,17 @@ def __init__(self, countfunc, subcon): self.countfunc = countfunc self._clear_flag(self.FLAG_COPY_CONTEXT) self._set_flag(self.FLAG_DYNAMIC) + + def count(self, context): + try: + return self.countfunc.value(context) + except AttributeError: + return self.countfunc(context) + def _parse(self, stream, context): obj = ListContainer() c = 0 - count = self.countfunc(context) + count = self.count(context) try: if self.subcon.conflags & self.FLAG_COPY_CONTEXT: while c < count: @@ -435,7 +449,7 @@ def _parse(self, stream, context): raise ArrayError("expected %d, found %d" % (count, c), ex) return obj def _build(self, obj, stream, context): - count = self.countfunc(context) + count = self.count(context) if len(obj) != count: raise ArrayError("expected %d, found %d" % (count, len(obj))) if self.subcon.conflags & self.FLAG_COPY_CONTEXT: @@ -445,7 +459,7 @@ def _build(self, obj, stream, context): for subobj in obj: self.subcon._build(subobj, stream, context) def _sizeof(self, context): - return self.subcon._sizeof(context) * self.countfunc(context) + return self.subcon._sizeof(context) * 
self.count(context) class Range(Subconstruct): """ @@ -561,6 +575,13 @@ def __init__(self, predicate, subcon): self.predicate = predicate self._clear_flag(self.FLAG_COPY_CONTEXT) self._set_flag(self.FLAG_DYNAMIC) + + def check(self, subobj, context): + try: + return self.predicate.check(subobj, context) + except AttributeError: + return self.predicate(subobj, context) + def _parse(self, stream, context): obj = [] try: @@ -568,13 +589,13 @@ def _parse(self, stream, context): while True: subobj = self.subcon._parse(stream, context.__copy__()) obj.append(subobj) - if self.predicate(subobj, context): + if self.check(subobj, context): break else: while True: subobj = self.subcon._parse(stream, context) obj.append(subobj) - if self.predicate(subobj, context): + if self.check(subobj, context): break except ConstructError, ex: raise ArrayError("missing terminator", ex) @@ -584,13 +605,13 @@ def _build(self, obj, stream, context): if self.subcon.conflags & self.FLAG_COPY_CONTEXT: for subobj in obj: self.subcon._build(subobj, stream, context.__copy__()) - if self.predicate(subobj, context): + if self.check(subobj, context): terminated = True break else: for subobj in obj: self.subcon._build(subobj, stream, context.__copy__()) - if self.predicate(subobj, context): + if self.check(subobj, context): terminated = True break if not terminated: @@ -821,8 +842,15 @@ def __init__(self, name, keyfunc, cases, default = NoDefault, self.include_key = include_key self._inherit_flags(*cases.values()) self._set_flag(self.FLAG_DYNAMIC) + + def getkey(self, context): + try: + return self.keyfunc.value(context) + except AttributeError: + return self.keyfunc(context) + def _parse(self, stream, context): - key = self.keyfunc(context) + key = self.getkey(context) obj = self.cases.get(key, self.default)._parse(stream, context) if self.include_key: return key, obj @@ -832,11 +860,11 @@ def _build(self, obj, stream, context): if self.include_key: key, obj = obj else: - key = self.keyfunc(context) + 
def _evaluate(operand, context):
    """
    Resolve an operand against the parsing context.

    Expression objects (anything exposing a callable ``value`` attribute, such
    as ValueOf, Equals or Constant) are evaluated with the context; any other
    object is treated as a literal and returned unchanged.
    """
    # Look the method up first and call it outside of any try/except, so an
    # AttributeError raised *inside* value() is not mistaken for "this operand
    # is a plain literal".
    resolver = getattr(operand, "value", None)
    if callable(resolver):
        return resolver(context)
    return operand


class ValueOf(object):
    """
    Named reference to an entry of the context, usable wherever a meta
    construct accepts a ``lambda ctx: ...`` (e.g. ``Field("rdata",
    ValueOf("rdata_length"))``).

    Parameters:
    * name - the key to look up in the context
    """

    def __init__(self, name):
        self.name = name

    def value(self, context):
        """Return ``context[self.name]``; a missing key raises whatever the
        context's item access raises."""
        return context[self.name]


class Equals(object):
    """
    Equality test usable both as a RepeatUntil predicate (``check``) and as a
    context expression (``value``).

    Parameters:
    * value - left operand; a literal or an expression object
    * othervalue - right operand for ``value()``; a literal or an expression
      object. Not used by ``check()``.
    """

    def __init__(self, value, othervalue=None):
        self.v = value
        # Keep the caller's operand (it used to be overwritten with None,
        # which made every two-operand comparison false).
        self.otherv = othervalue

    def check(self, obj, context):
        """Return True if the stored value equals ``obj``; ``context`` is
        accepted for signature compatibility and is not consulted."""
        return self.v == obj

    def value(self, context):
        """Resolve both operands against ``context`` and compare them."""
        return _evaluate(self.v, context) == _evaluate(self.otherv, context)


class Constant(object):
    """
    Expression that ignores the context and always yields the same object,
    e.g. ``Switch("switch", Constant(5), {...})``.

    Parameters:
    * value - the object to return
    """

    def __init__(self, value):
        self.v = value

    def check(self, obj, context):
        """Return the stored value regardless of ``obj`` and ``context``."""
        return self.v

    def value(self, context):
        """Return the stored value regardless of ``context``."""
        return self.v


class Boolean(object):
    """
    Coerce the result of a predicate to ``bool``.

    Parameters:
    * pred - an expression object (evaluated through its ``value`` method) or
      a plain callable taking the context
    """

    def __init__(self, pred):
        self.pred = pred

    def value(self, context):
        """Evaluate the wrapped predicate with ``context`` and return it as a
        bool."""
        resolver = getattr(self.pred, "value", None)
        if not callable(resolver):
            # Plain function / lambda: call it directly with the context.
            resolver = self.pred
        return bool(resolver(context))
a/construct/macros.py +++ b/construct/macros.py @@ -1,7 +1,7 @@ from construct.lib import BitStreamReader, BitStreamWriter, encode_bin, decode_bin from construct.core import (Struct, MetaField, StaticField, FormatField, OnDemand, Pointer, Switch, Value, RepeatUntil, MetaArray, Sequence, Range, - Select, Pass, SizeofError, Buffered, Restream, Reconfig) + Select, Pass, SizeofError, Buffered, Restream, Reconfig, Boolean) from construct.adapters import (BitIntegerAdapter, PaddingAdapter, ConstAdapter, CStringAdapter, LengthValueAdapter, IndexingAdapter, PaddedStringAdapter, FlagsAdapter, StringAdapter, MappingAdapter) @@ -19,10 +19,11 @@ def Field(name, length): (StaticField), or a function that takes the context as an argument and returns the length (MetaField) """ - if callable(length): + try: + return StaticField(name, int(length)) + except TypeError: return MetaField(name, length) - else: - return StaticField(name, length) + def BitField(name, length, swapped = False, signed = False, bytesize = 8): """ @@ -240,7 +241,7 @@ def Array(count, subcon): construct.core.RangeError: expected 4..4, found 5 """ - if callable(count): + if callable(count) or hasattr(count, 'value'): #blegh con = MetaArray(count, subcon) else: con = MetaArray(lambda ctx: count, subcon) @@ -591,7 +592,7 @@ def IfThenElse(name, predicate, then_subcon, else_subcon): * then_subcon - the subcon that will be used if the predicate returns True * else_subcon - the subcon that will be used if the predicate returns False """ - return Switch(name, lambda ctx: bool(predicate(ctx)), + return Switch(name, Boolean(predicate), { True : then_subcon, False : else_subcon, diff --git a/construct/protocols/application/dns.py b/construct/protocols/application/dns.py index e98ac098b..a80ec376d 100644 --- a/construct/protocols/application/dns.py +++ b/construct/protocols/application/dns.py @@ -38,7 +38,7 @@ def _decode(self, obj, context): query_record = Struct("query_record", DnsStringAdapter( - RepeatUntil(lambda 
obj, ctx: obj == "", + RepeatUntil(Equals(""), PascalString("name") ) ), @@ -46,7 +46,7 @@ def _decode(self, obj, context): dns_record_class, ) -rdata = Field("rdata", lambda ctx: ctx.rdata_length) +rdata = Field("rdata", ValueOf("rdata_length")) resource_record = Struct("resource_record", CString("name", terminators = "\xc0\x00"), @@ -55,7 +55,7 @@ def _decode(self, obj, context): dns_record_class, UBInt32("ttl"), UBInt16("rdata_length"), - IfThenElse("data", lambda ctx: ctx.type == "IPv4", + IfThenElse("data", Equals(ValueOf("type"), "IPv4"), IpAddressAdapter(rdata), rdata ) @@ -100,16 +100,16 @@ def _decode(self, obj, context): UBInt16("answer_count"), UBInt16("authority_count"), UBInt16("additional_count"), - Array(lambda ctx: ctx.question_count, + Array(ValueOf("question_count"), Rename("questions", query_record), ), Rename("answers", - Array(lambda ctx: ctx.answer_count, resource_record) + Array(ValueOf("answer_count"), resource_record) ), Rename("authorities", - Array(lambda ctx: ctx.authority_count, resource_record) + Array(ValueOf("authority_count"), resource_record) ), - Array(lambda ctx: ctx.additional_count, + Array(ValueOf("additional_count"), Rename("additionals", resource_record), ), ) diff --git a/construct/tests/test_core.py b/construct/tests/test_core.py index e857d86fd..c46138e79 100644 --- a/construct/tests/test_core.py +++ b/construct/tests/test_core.py @@ -1,7 +1,7 @@ import unittest from construct import Struct, MetaField, StaticField, FormatField -from construct import Container, Byte +from construct import Container, Byte, ValueOf from construct import FieldError, SizeofError class TestStaticField(unittest.TestCase): @@ -73,10 +73,10 @@ def test_build_too_short(self): def test_sizeof(self): self.assertEqual(self.mf.sizeof(), 3) -class TestMetaFieldStruct(unittest.TestCase): +class TestParamMetaFieldStruct(unittest.TestCase): def setUp(self): - self.mf = MetaField("data", lambda context: context["length"]) + self.mf = MetaField("data", 
ValueOf("length")) self.s = Struct("foo", Byte("length"), self.mf) def test_trivial(self): diff --git a/construct/tests/unit.py b/construct/tests/unit.py index c046035c6..06ff8188c 100644 --- a/construct/tests/unit.py +++ b/construct/tests/unit.py @@ -15,6 +15,7 @@ # constructs # [MetaArray(lambda ctx: 3, UBInt8("metaarray")).parse, "\x01\x02\x03", [1,2,3], None], + [MetaArray(lambda ctx: 3, UBInt8("metaarray")).parse, "\x01\x02\x03", [1,2,3], None], [MetaArray(lambda ctx: 3, UBInt8("metaarray")).parse, "\x01\x02", None, ArrayError], [MetaArray(lambda ctx: 3, UBInt8("metaarray")).build, [1,2,3], "\x01\x02\x03", None], [MetaArray(lambda ctx: 3, UBInt8("metaarray")).build, [1,2], None, ArrayError], @@ -31,9 +32,7 @@ [RepeatUntil(lambda obj, ctx: obj == 9, UBInt8("repeatuntil")).parse, "\x02\x03\x09", [2,3,9], None], [RepeatUntil(lambda obj, ctx: obj == 9, UBInt8("repeatuntil")).parse, "\x02\x03\x08", None, ArrayError], - [RepeatUntil(lambda obj, ctx: obj == 9, UBInt8("repeatuntil")).build, [2,3,9], "\x02\x03\x09", None], - [RepeatUntil(lambda obj, ctx: obj == 9, UBInt8("repeatuntil")).build, [2,3,8], None, ArrayError], - + [RepeatUntil(Equals(9), UBInt8("repeatuntil")).build, [2,3,9], "\x02\x03\x09", None], [Struct("struct", UBInt8("a"), UBInt16("b")).parse, "\x01\x00\x02", Container(a=1,b=2), None], [Struct("struct", UBInt8("a"), UBInt16("b"), Struct("foo", UBInt8("c"), UBInt8("d"))).parse, "\x01\x00\x02\x03\x04", Container(a=1,b=2,foo=Container(c=3,d=4)), None], [Struct("struct", UBInt8("a"), UBInt16("b"), Embedded(Struct("foo", UBInt8("c"), UBInt8("d")))).parse, "\x01\x00\x02\x03\x04", Container(a=1,b=2,c=3,d=4), None], @@ -49,6 +48,7 @@ [Sequence("sequence", UBInt8("a"), UBInt16("b"), Embedded(Sequence("foo", UBInt8("c"), UBInt8("d")))).build, [1,2,3,4], "\x01\x00\x02\x03\x04", None], [Switch("switch", lambda ctx: 5, {1:UBInt8("x"), 5:UBInt16("y")}).parse, "\x00\x02", 2, None], + [Switch("switch", Constant(5), {1:UBInt8("x"), 5:UBInt16("y")}).parse, "\x00\x02", 
2, None], [Switch("switch", lambda ctx: 6, {1:UBInt8("x"), 5:UBInt16("y")}).parse, "\x00\x02", None, SwitchError], [Switch("switch", lambda ctx: 6, {1:UBInt8("x"), 5:UBInt16("y")}, default = UBInt8("x")).parse, "\x00\x02", 0, None], [Switch("switch", lambda ctx: 5, {1:UBInt8("x"), 5:UBInt16("y")}, include_key = True).parse, "\x00\x02", (5, 2), None], @@ -174,6 +174,8 @@ [LengthValueAdapter(Sequence("lengthvalueadapter", UBInt8("length"), Field("value", lambda ctx: ctx.length))).parse, "\x05abcde", "abcde", None], + [LengthValueAdapter(Sequence("lengthvalueadapter", UBInt8("length"), Field("value", ValueOf('length')))).parse, + "\x05abcde", "abcde", None], [LengthValueAdapter(Sequence("lengthvalueadapter", UBInt8("length"), Field("value", lambda ctx: ctx.length))).build, "abcde", "\x05abcde", None], @@ -241,6 +243,7 @@ [Bitwise(Field("bitwise", 8)).parse, "\xff", "\x01" * 8, None], [Bitwise(Field("bitwise", lambda ctx: 8)).parse, "\xff", "\x01" * 8, None], + [Bitwise(Field("bitwise", Constant(8))).parse, "\xff", "\x01" * 8, None], [Bitwise(Field("bitwise", 8)).build, "\x01" * 8, "\xff", None], [Bitwise(Field("bitwise", lambda ctx: 8)).build, "\x01" * 8, "\xff", None], @@ -296,7 +299,9 @@ 1, "\x01", None], [IfThenElse("ifthenelse", lambda ctx: False, UBInt8("then"), UBInt16("else")).build, 1, "\x00\x01", None], - + [IfThenElse("ifthenelse", Constant(False), UBInt8("then"), UBInt16("else")).build, + 1, "\x00\x01", None], + [IfThenElse("ifthenelse", Constant(True), UBInt8("then"), UBInt16("else")).build, 1, "\x01", None], [Magic("MZ").parse, "MZ", "MZ", None], [Magic("MZ").parse, "ELF", None, ConstError], [Magic("MZ").build, None, "MZ", None],