Skip to content
This repository was archived by the owner on Sep 1, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions construct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def wrapper(*args, **kwargs):
'AdaptationError', 'Adapter', 'Alias', 'Aligned', 'AlignedStruct',
'Anchor', 'Array', 'ArrayError', 'BFloat32', 'BFloat64', 'Bit', 'BitField',
'BitIntegerAdapter', 'BitIntegerError', 'BitStruct', 'Bits', 'Bitwise',
'Buffered', 'Byte', 'Bytes', 'CString', 'CStringAdapter', 'Const',
'Buffered', 'Byte', 'Bytes', 'CString', 'CStringAdapter', 'Const', 'Constant',
'ConstAdapter', 'ConstError', 'Construct', 'ConstructError', 'Container',
'Debugger', 'Embed', 'Embedded', 'EmbeddedBitStruct', 'Enum', 'ExprAdapter',
'Field', 'FieldError', 'Flag', 'FlagsAdapter', 'FlagsContainer',
'Equals', 'Field', 'FieldError', 'Flag', 'FlagsAdapter', 'FlagsContainer',
'FlagsEnum', 'FormatField', 'GreedyRange', 'GreedyRepeater',
'HexDumpAdapter', 'If', 'IfThenElse', 'IndexingAdapter', 'LFloat32',
'LFloat64', 'LazyBound', 'LengthValueAdapter', 'ListContainer',
Expand All @@ -106,5 +106,5 @@ def wrapper(*args, **kwargs):
'SymmetricMapping', 'Terminator', 'TerminatorError', 'Tunnel',
'TunnelAdapter', 'UBInt16', 'UBInt32', 'UBInt64', 'UBInt8', 'ULInt16',
'ULInt32', 'ULInt64', 'ULInt8', 'UNInt16', 'UNInt32', 'UNInt64', 'UNInt8',
'Union', 'ValidationError', 'Validator', 'Value', "Magic",
'Union', 'ValidationError', 'Validator', 'Value', 'ValueOf', "Magic",
]
118 changes: 103 additions & 15 deletions construct/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,19 @@ def __init__(self, name, lengthfunc):
Construct.__init__(self, name)
self.lengthfunc = lengthfunc
self._set_flag(self.FLAG_DYNAMIC)

def length(self, context):
try:
return self.lengthfunc.value(context)
except AttributeError:
return self.lengthfunc(context)

def _parse(self, stream, context):
return _read_stream(stream, self.lengthfunc(context))
return _read_stream(stream, self.length(context))
def _build(self, obj, stream, context):
_write_stream(stream, self.lengthfunc(context), obj)
_write_stream(stream, self.length(context), obj)
def _sizeof(self, context):
return self.lengthfunc(context)
return self.length(context)


#===============================================================================
Expand All @@ -418,10 +425,17 @@ def __init__(self, countfunc, subcon):
self.countfunc = countfunc
self._clear_flag(self.FLAG_COPY_CONTEXT)
self._set_flag(self.FLAG_DYNAMIC)

def count(self, context):
try:
return self.countfunc.value(context)
except AttributeError:
return self.countfunc(context)

def _parse(self, stream, context):
obj = ListContainer()
c = 0
count = self.countfunc(context)
count = self.count(context)
try:
if self.subcon.conflags & self.FLAG_COPY_CONTEXT:
while c < count:
Expand All @@ -435,7 +449,7 @@ def _parse(self, stream, context):
raise ArrayError("expected %d, found %d" % (count, c), ex)
return obj
def _build(self, obj, stream, context):
count = self.countfunc(context)
count = self.count(context)
if len(obj) != count:
raise ArrayError("expected %d, found %d" % (count, len(obj)))
if self.subcon.conflags & self.FLAG_COPY_CONTEXT:
Expand All @@ -445,7 +459,7 @@ def _build(self, obj, stream, context):
for subobj in obj:
self.subcon._build(subobj, stream, context)
def _sizeof(self, context):
return self.subcon._sizeof(context) * self.countfunc(context)
return self.subcon._sizeof(context) * self.count(context)

class Range(Subconstruct):
"""
Expand Down Expand Up @@ -561,20 +575,27 @@ def __init__(self, predicate, subcon):
self.predicate = predicate
self._clear_flag(self.FLAG_COPY_CONTEXT)
self._set_flag(self.FLAG_DYNAMIC)

def check(self, subobj, context):
try:
return self.predicate.check(subobj, context)
except AttributeError:
return self.predicate(subobj, context)

def _parse(self, stream, context):
obj = []
try:
if self.subcon.conflags & self.FLAG_COPY_CONTEXT:
while True:
subobj = self.subcon._parse(stream, context.__copy__())
obj.append(subobj)
if self.predicate(subobj, context):
if self.check(subobj, context):
break
else:
while True:
subobj = self.subcon._parse(stream, context)
obj.append(subobj)
if self.predicate(subobj, context):
if self.check(subobj, context):
break
except ConstructError, ex:
raise ArrayError("missing terminator", ex)
Expand All @@ -584,13 +605,13 @@ def _build(self, obj, stream, context):
if self.subcon.conflags & self.FLAG_COPY_CONTEXT:
for subobj in obj:
self.subcon._build(subobj, stream, context.__copy__())
if self.predicate(subobj, context):
if self.check(subobj, context):
terminated = True
break
else:
for subobj in obj:
self.subcon._build(subobj, stream, context.__copy__())
if self.predicate(subobj, context):
if self.check(subobj, context):
terminated = True
break
if not terminated:
Expand Down Expand Up @@ -821,8 +842,15 @@ def __init__(self, name, keyfunc, cases, default = NoDefault,
self.include_key = include_key
self._inherit_flags(*cases.values())
self._set_flag(self.FLAG_DYNAMIC)

def getkey(self, context):
try:
return self.keyfunc.value(context)
except AttributeError:
return self.keyfunc(context)

def _parse(self, stream, context):
key = self.keyfunc(context)
key = self.getkey(context)
obj = self.cases.get(key, self.default)._parse(stream, context)
if self.include_key:
return key, obj
Expand All @@ -832,11 +860,11 @@ def _build(self, obj, stream, context):
if self.include_key:
key, obj = obj
else:
key = self.keyfunc(context)
key = self.getkey(context)
case = self.cases.get(key, self.default)
case._build(obj, stream, context)
def _sizeof(self, context):
case = self.cases.get(self.keyfunc(context), self.default)
case = self.cases.get(self.getkey(context), self.default)
return case._sizeof(context)

class Select(Construct):
Expand Down Expand Up @@ -1200,10 +1228,19 @@ def __init__(self, name, func):
Construct.__init__(self, name)
self.func = func
self._set_flag(self.FLAG_DYNAMIC)

def getval(self, context):
try:
return self.func.value(context)
except AttributeError:
return self.func(context)

def _parse(self, stream, context):
return self.func(context)
return self.getval(context)

def _build(self, obj, stream, context):
context[self.name] = self.func(context)
context[self.name] = self.getval(context)

def _sizeof(self, context):
return 0

Expand Down Expand Up @@ -1319,3 +1356,54 @@ def _build(self, obj, stream, context):
def _sizeof(self, context):
return 0
Terminator = Terminator(None)


class ValueOf(object):

def __init__(self, name):
self.name = name

def value(self, context):
print context, self.name
return context[self.name]


class Equals(object):
def __init__(self, value, othervalue=None):
self.v = value
self.otherv = None

def check(self, obj, context):
return self.v == obj

def value(self, context):
try:
left = self.v.value(context)
except AttributeError:
left = self.value
try:
right = self.otherv.value(context)
except AttributeError:
right = self.otherv
return left == right


class Constant(object):
def __init__(self, value):
self.v = value

def check(self, obj, context):
return self.v

def value(self, context):
return self.v

class Boolean(object):
def __init__(self, pred):
self.pred = pred

def value(self, context):
try:
return bool(self.pred.value(context))
except AttributeError:
return bool(self.pred(context))
13 changes: 7 additions & 6 deletions construct/macros.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from construct.lib import BitStreamReader, BitStreamWriter, encode_bin, decode_bin
from construct.core import (Struct, MetaField, StaticField, FormatField,
OnDemand, Pointer, Switch, Value, RepeatUntil, MetaArray, Sequence, Range,
Select, Pass, SizeofError, Buffered, Restream, Reconfig)
Select, Pass, SizeofError, Buffered, Restream, Reconfig, Boolean)
from construct.adapters import (BitIntegerAdapter, PaddingAdapter,
ConstAdapter, CStringAdapter, LengthValueAdapter, IndexingAdapter,
PaddedStringAdapter, FlagsAdapter, StringAdapter, MappingAdapter)
Expand All @@ -19,10 +19,11 @@ def Field(name, length):
(StaticField), or a function that takes the context as an argument and
returns the length (MetaField)
"""
if callable(length):
try:
return StaticField(name, int(length))
except TypeError:
return MetaField(name, length)
else:
return StaticField(name, length)


def BitField(name, length, swapped = False, signed = False, bytesize = 8):
"""
Expand Down Expand Up @@ -240,7 +241,7 @@ def Array(count, subcon):
construct.core.RangeError: expected 4..4, found 5
"""

if callable(count):
if callable(count) or hasattr(count, 'value'): #blegh
con = MetaArray(count, subcon)
else:
con = MetaArray(lambda ctx: count, subcon)
Expand Down Expand Up @@ -591,7 +592,7 @@ def IfThenElse(name, predicate, then_subcon, else_subcon):
* then_subcon - the subcon that will be used if the predicate returns True
* else_subcon - the subcon that will be used if the predicate returns False
"""
return Switch(name, lambda ctx: bool(predicate(ctx)),
return Switch(name, Boolean(predicate),
{
True : then_subcon,
False : else_subcon,
Expand Down
14 changes: 7 additions & 7 deletions construct/protocols/application/dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def _decode(self, obj, context):

query_record = Struct("query_record",
DnsStringAdapter(
RepeatUntil(lambda obj, ctx: obj == "",
RepeatUntil(Equals(""),
PascalString("name")
)
),
dns_record_type,
dns_record_class,
)

rdata = Field("rdata", lambda ctx: ctx.rdata_length)
rdata = Field("rdata", ValueOf("rdata_length"))

resource_record = Struct("resource_record",
CString("name", terminators = "\xc0\x00"),
Expand All @@ -55,7 +55,7 @@ def _decode(self, obj, context):
dns_record_class,
UBInt32("ttl"),
UBInt16("rdata_length"),
IfThenElse("data", lambda ctx: ctx.type == "IPv4",
IfThenElse("data", Equals(ValueOf("type"), "IPv4"),
IpAddressAdapter(rdata),
rdata
)
Expand Down Expand Up @@ -100,16 +100,16 @@ def _decode(self, obj, context):
UBInt16("answer_count"),
UBInt16("authority_count"),
UBInt16("additional_count"),
Array(lambda ctx: ctx.question_count,
Array(ValueOf("question_count"),
Rename("questions", query_record),
),
Rename("answers",
Array(lambda ctx: ctx.answer_count, resource_record)
Array(ValueOf("answer_count"), resource_record)
),
Rename("authorities",
Array(lambda ctx: ctx.authority_count, resource_record)
Array(ValueOf("authority_count"), resource_record)
),
Array(lambda ctx: ctx.additional_count,
Array(ValueOf("additional_count"),
Rename("additionals", resource_record),
),
)
Expand Down
6 changes: 3 additions & 3 deletions construct/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from construct import Struct, MetaField, StaticField, FormatField
from construct import Container, Byte
from construct import Container, Byte, ValueOf
from construct import FieldError, SizeofError

class TestStaticField(unittest.TestCase):
Expand Down Expand Up @@ -73,10 +73,10 @@ def test_build_too_short(self):
def test_sizeof(self):
self.assertEqual(self.mf.sizeof(), 3)

class TestMetaFieldStruct(unittest.TestCase):
class TestParamMetaFieldStruct(unittest.TestCase):

def setUp(self):
self.mf = MetaField("data", lambda context: context["length"])
self.mf = MetaField("data", ValueOf("length"))
self.s = Struct("foo", Byte("length"), self.mf)

def test_trivial(self):
Expand Down
Loading