diff --git a/README.md b/README.md index eb85940..7ae1bf4 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ # SMDA SMDA is a minimalist recursive disassembler library that is optimized for accurate Control Flow Graph (CFG) recovery from memory dumps. -It is based on [Capstone](http://www.capstone-engine.org/) and currently supports x86/x64 Intel machine code. -As input, arbitrary memory dumps (ideally with known base address) can be processed. +It is based on [Capstone](http://www.capstone-engine.org/) and currently supports x86/x64 Intel machine code, experimental CIL (.NET) disassembly, and Dalvik bytecode from raw DEX files. +As input, arbitrary memory dumps (ideally with known base address) can be processed, and raw DEX files can be analyzed directly. The output is a collection of functions, basic blocks, and instructions with their respective edges between blocks and functions (in/out). Optionally, references to the Windows API can be inferred by using the ApiScout method. @@ -48,6 +48,8 @@ There is also a demo script: * analyze.py -- example usage: perform disassembly on a file or memory dump and optionally store results in JSON to a given output path. +For Dalvik, the current scope is raw single-DEX inputs. APK, multi-dex container handling, and ODEX/VDEX/CDEX runtime-artifact analysis are not yet first-class workflows in SMDA. + The code should be fully compatible with Python 3.8+. Further explanation on the innerworkings follow in separate publications but will be referenced here. 
diff --git a/analyze.py b/analyze.py index a05cfb3..14e0b50 100644 --- a/analyze.py +++ b/analyze.py @@ -4,12 +4,14 @@ import os import re import sys +import textwrap from smda.Disassembler import Disassembler from smda.SmdaConfig import SmdaConfig +from smda.utility.DexFileLoader import DexFileLoader -def parseBaseAddrFromArgs(args): +def parseBaseAddrFromArgs(args, silent=False): if args.base_addr: parsed_base_addr = int(args.base_addr, 16) if args.base_addr.startswith("0x") else int(args.base_addr) logging.info("using provided base address: 0x%08x", parsed_base_addr) @@ -18,24 +20,66 @@ def parseBaseAddrFromArgs(args): baddr_match = re.search(re.compile("_0x(?P<base_addr>[0-9a-fA-F]{8,16})"), args.input_path) if baddr_match: parsed_base_addr = int(baddr_match.group("base_addr"), 16) - logging.info( - "Parsed base address from file name: 0x%08x", - parsed_base_addr, - ) + logging.info("Parsed base address from file name: 0x%08x", parsed_base_addr) return parsed_base_addr - logging.warning("No base address recognized, using 0.") + if not silent: + logging.warning("No base address recognized, using 0.") return 0 -def parseOepFromArgs(args): +def parseOepFromArgs(args, silent=False): if args.oep and args.oep != "": parsed_oep = int(args.oep, 16) if args.oep.startswith("0x") else int(args.oep) logging.info("using provided OEP(RVA): 0x%08x", parsed_oep) return parsed_oep - logging.warning("No OEP recognized, skipping.") + if not silent: + logging.warning("No OEP recognized, skipping.") return None +def _printDalvikSummary(report, output_path, input_filename): + """Structured one-screen summary printed to stdout after Dalvik/DEX analysis.""" + size_bytes = report.binary_size or 0 + size_str = f"{size_bytes / 1024 / 1024:.1f} MB" if size_bytes >= 1024 * 1024 else f"{size_bytes / 1024:.1f} KB" + + dex_version = report.version if report.version else "?" 
+ bitness_str = f".{report.bitness}bit" if report.bitness else "" + stats = report.statistics + + # Aggregate heuristic tags and string-ref count across all functions + heuristic_counts = {} + string_ref_total = 0 + for fn in report.getFunctions(): + for tag in (fn.architecture_metadata or {}).get("heuristics", []): + heuristic_counts[tag] = heuristic_counts.get(tag, 0) + 1 + string_ref_total += len(fn.stringrefs or {}) + + print(f"[*] File: {input_filename} ({size_str})") + print(f"[*] Architecture: {report.architecture}{bitness_str}") + print(f"[*] Format: Dalvik DEX v{dex_version}") + print(f"[*] Time: {report.execution_time:.3f}s") + print(f"[*] Functions: {stats.num_functions:,}") + print(f"[*] CFG: {stats.num_basic_blocks:,} blocks / {stats.num_instructions:,} instructions") + print(f"[*] Refs: api={stats.num_api_calls:,} strings={string_ref_total:,}") + + if heuristic_counts: + tags = sorted(heuristic_counts.items(), key=lambda kv: -kv[1]) + joined = " ".join(f"{k}={v}" for k, v in tags) + lines = textwrap.wrap(joined, width=52, break_long_words=False, break_on_hyphens=False) or [""] + print(f"[!] 
Heuristics: {lines[0]}") + for line in lines[1:]: + print(f" {line}") + + if output_path and os.path.isdir(output_path): + print(f"[+] Saved: {os.path.join(output_path, input_filename + '.smda')}") + + +def _getInteractiveStream(stream): + if hasattr(stream, "reconfigure"): + stream.reconfigure(errors="backslashreplace") + return stream + + def readFileContent(file_path): file_content = b"" with open(file_path, "rb") as fin: @@ -66,7 +110,7 @@ def readFileContent(file_path): "--architecture", type=str, default="", - help="Use the disassembler for the following architecture if available (default:auto, options: [intel, cil]).", + help="Use the disassembler for the following architecture if available (default:auto, options: [intel, cil, dalvik]).", ) PARSER.add_argument( "-a", @@ -120,11 +164,18 @@ def readFileContent(file_path): # optionally create and set up a config, e.g. when using ApiScout profiles for WinAPI import usage discovery config = SmdaConfig() - if ARGS.verbose: - config.LOG_LEVEL = logging.DEBUG if ARGS.strings: config.WITH_STRINGS = True - logging.basicConfig(level=config.LOG_LEVEL, format=config.LOG_FORMAT) + if ARGS.verbose: + config.LOG_LEVEL = logging.DEBUG + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s.%(msecs)03d %(levelname).1s %(name)s: %(message)s", + datefmt="%H:%M:%S", + stream=_getInteractiveStream(sys.stdout), + ) + else: + logging.basicConfig(level=config.LOG_LEVEL, format=config.LOG_FORMAT) SMDA_REPORT = None INPUT_FILENAME = "" BITNESS = ARGS.bitness if (ARGS.bitness in [32, 64]) else None @@ -136,15 +187,24 @@ def readFileContent(file_path): SMDA_REPORT = DISASSEMBLER.disassembleFile(ARGS.input_path, pdb_path=ARGS.pdb_path) else: BUFFER = readFileContent(ARGS.input_path) - BASE_ADDR = parseBaseAddrFromArgs(ARGS) - OEP = parseOepFromArgs(ARGS) + treat_as_dalvik = ARGS.architecture in {"", "dalvik"} and DexFileLoader.isCompatible(BUFFER) + if treat_as_dalvik: + BASE_ADDR = DexFileLoader.getBaseAddress(BUFFER) + OEP 
= None + else: + BASE_ADDR = parseBaseAddrFromArgs(ARGS) + OEP = parseOepFromArgs(ARGS) config.API_COLLECTION_FILES = { "win_7": os.sep.join([config.PROJECT_ROOT, "data", "apiscout_win7_prof-n_sp1.json"]) } DISASSEMBLER = Disassembler(config, backend=ARGS.architecture) SMDA_REPORT = DISASSEMBLER.disassembleBuffer(BUFFER, BASE_ADDR, BITNESS, oep=OEP) SMDA_REPORT.filename = os.path.basename(ARGS.input_path) - print(SMDA_REPORT) - if SMDA_REPORT and os.path.isdir(ARGS.output_path): - with open(ARGS.output_path + os.sep + INPUT_FILENAME + ".smda", "w") as fout: + if SMDA_REPORT.architecture == "dalvik": + _printDalvikSummary(SMDA_REPORT, ARGS.output_path, INPUT_FILENAME) + else: + print(SMDA_REPORT) + if SMDA_REPORT and ARGS.output_path and os.path.isdir(ARGS.output_path): + output_file = os.path.join(ARGS.output_path, INPUT_FILENAME + ".smda") + with open(output_file, "w") as fout: json.dump(SMDA_REPORT.toDict(), fout, indent=1, sort_keys=True) diff --git a/smda/Disassembler.py b/smda/Disassembler.py index c2b6641..6ce1253 100644 --- a/smda/Disassembler.py +++ b/smda/Disassembler.py @@ -7,9 +7,11 @@ from smda.common.BinaryInfo import BinaryInfo from smda.common.labelprovider.GoLabelProvider import GoSymbolProvider from smda.common.SmdaReport import SmdaReport +from smda.dalvik.DalvikDisassembler import DalvikDisassembler from smda.ida.IdaExporter import IdaExporter from smda.intel.IntelDisassembler import IntelDisassembler from smda.SmdaConfig import SmdaConfig +from smda.utility.DexFileLoader import DexFileLoader from smda.utility.FileLoader import FileLoader from smda.utility.MemoryFileLoader import MemoryFileLoader from smda.utility.StringExtractor import extract_strings @@ -27,10 +29,13 @@ def __init__(self, config=None, backend=None): self.disassembler = IntelDisassembler(self.config) elif backend == "cil": self.disassembler = CilDisassembler(self.config) + elif backend == "dalvik": + self.disassembler = DalvikDisassembler(self.config) elif backend == "IDA": 
self.disassembler = IdaExporter(self.config) self._start_time = None self._timeout = 0 + self._last_timeout_log_second = -1 # cache the last DisassemblyResult self.disassembly = None @@ -41,6 +46,8 @@ def initDisassembler(self, architecture="intel"): self.disassembler = IntelDisassembler(self.config) elif architecture == "cil": self.disassembler = CilDisassembler(self.config) + elif architecture == "dalvik": + self.disassembler = DalvikDisassembler(self.config) def _getDurationInSeconds(self, start_ts, end_ts): return (end_ts - start_ts).seconds + ((end_ts - start_ts).microseconds / 1000000.0) @@ -49,12 +56,33 @@ def _callbackAnalysisTimeout(self): if not self._timeout: return False time_diff = datetime.datetime.now(datetime.timezone.utc) - self._start_time - LOGGER.debug("Current analysis callback time %s", (time_diff)) - return time_diff.seconds >= self._timeout + elapsed_seconds = int(time_diff.total_seconds()) + if elapsed_seconds >= self._timeout: + LOGGER.debug("Current analysis callback time %s", time_diff) + return True + # Log on 30s bucket transitions (not exact-second boundaries) so the message + # still fires when callback timing skips past 30/60/... whole seconds. 
+ current_bucket = elapsed_seconds // 30 + if current_bucket >= 1 and current_bucket != self._last_timeout_log_second: + self._last_timeout_log_second = current_bucket + LOGGER.debug("Current analysis callback time %s", time_diff) + return False def _addStringsToReport(self, smda_report, buffer, mode=None): smda_report.buffer = buffer for smda_function in smda_report.getFunctions(): + if smda_report.architecture == "dalvik": + if smda_function.stringrefs and isinstance(smda_function.stringrefs, dict): + smda_function.stringrefs = [ + { + "string": string_value, + "ins_addr": referencing_addr, + "data_addr": None, + "type": "dex", + } + for referencing_addr, string_value in sorted(smda_function.stringrefs.items()) + ] + continue function_strings = [] for string_result in extract_strings(smda_function, mode=mode): string, referencing_addr, string_addr, string_type = string_result @@ -150,6 +178,16 @@ def disassembleBuffer( Disassemble a given buffer (file_content), with given base_addr. Optionally specify bitness, the areas to which disassembly should be limited to (code_areas) and an entry point (oep) """ + # Auto-detect DEX when the caller did not explicitly override architecture. + # disassembleUnmappedBuffer / disassembleFile already use FileLoader for detection; + # this path bypasses it, so we check the magic bytes manually here. + if architecture == "intel" and DexFileLoader.isCompatible(file_content): + architecture = "dalvik" + if bitness is None: + bitness = DexFileLoader.getBitness(file_content) + # initDisassembler caches by self.disassembler-is-None, so a backend + # picked at construction time would otherwise win against autodetect. 
+ self.disassembler = None binary_info = BinaryInfo(file_content) binary_info.base_addr = base_addr binary_info.bitness = bitness @@ -176,6 +214,7 @@ def disassembleBuffer( def _disassemble(self, binary_info, timeout=0): self._start_time = datetime.datetime.now(datetime.timezone.utc) self._timeout = timeout + self._last_timeout_log_second = -1 self._ensureHashes(binary_info) if self.disassembler: self.disassembly = self.disassembler.analyzeBuffer(binary_info, self._callbackAnalysisTimeout) diff --git a/smda/DisassemblyResult.py b/smda/DisassemblyResult.py index 0210f5f..545e3fb 100644 --- a/smda/DisassemblyResult.py +++ b/smda/DisassemblyResult.py @@ -24,6 +24,7 @@ def __init__(self): self.exported_functions = set() self.failed_analysis_addr = [] self.function_borders = {} + self.function_metadata = {} # stored as key: int(i.address) = (i.size, i.mnemonic, i.op_str) self.instructions = {} self.ins2fn = {} @@ -127,6 +128,49 @@ def getMnemonic(self, instruction_addr): return self.instructions[instruction_addr][0] return "" + def _getContainingBlockStart(self, blocks, instruction_addr): + for block in blocks: + if not block: + continue + block_start = block[0][0] + block_end = block[-1][0] + block[-1][1] + if block_start <= instruction_addr < block_end: + return block_start + return None + + def _getExceptionSuccessors(self, func_addr, blocks): + metadata = self.function_metadata.get(func_addr, {}) + try_ranges = metadata.get("try_ranges", []) + if not try_ranges: + return {} + block_successors = {} + for try_range in try_ranges: + raw_targets = [] + for handler in try_range.get("handlers", []): + target_addr = handler.get("target_addr") if isinstance(handler, dict) else None + if target_addr is not None: + raw_targets.append(target_addr) + if try_range.get("catch_all_addr") is not None: + raw_targets.append(try_range["catch_all_addr"]) + if not raw_targets: + continue + normalized_targets = set() + for target_addr in raw_targets: + block_start = 
self._getContainingBlockStart(blocks, target_addr) + if block_start is None: + block_start = target_addr + normalized_targets.add(block_start) + for block in blocks: + if not block: + continue + block_start = block[0][0] + block_end = block[-1][0] + block[-1][1] + if try_range["start_addr"] < block_end and block_start < try_range["end_addr"]: + successors = block_successors.get(block_start, set()) + successors.update(normalized_targets) + block_successors[block_start] = successors + return block_successors + def isCode(self, addr): return addr in self.code_map @@ -190,20 +234,26 @@ def removeDataRefs(self, addr_from, addr_to): self.data_refs_to[addr_to] = refs_to def getBlockRefs(self, func_addr): - """blocks refs should stay within function context, thus kill all references outside function""" - block_refs = {} - ins_addrs = set() - for block in self.functions[func_addr]: - for ins in block: - ins_addr = ins[0] - ins_addrs.add(ins_addr) - for block in self.functions[func_addr]: + """Return a normalized intra-function CFG keyed by block start.""" + if func_addr not in self.functions: + return {} + blocks = [block for block in self.functions[func_addr] if block] + block_starts = {block[0][0] for block in blocks} + block_refs = {block_start: [] for block_start in sorted(block_starts)} + for block in blocks: last_ins_addr = block[-1][0] - if last_ins_addr in self.code_refs_from: - verified_refs = sorted(ins_addrs.intersection(self.code_refs_from[last_ins_addr])) - if verified_refs: - block_refs[block[0][0]] = verified_refs - return block_refs + if last_ins_addr not in self.code_refs_from: + continue + verified_refs = sorted(block_starts.intersection(self.code_refs_from[last_ins_addr])) + if verified_refs: + block_refs[block[0][0]] = verified_refs + for block_start, successors in self._getExceptionSuccessors(func_addr, blocks).items(): + merged_successors = set(block_refs.get(block_start, [])) + merged_successors.update(successors) + block_refs[block_start] = 
sorted(merged_successors) + for successor in successors: + block_refs.setdefault(successor, []) + return {block_start: block_refs[block_start] for block_start in sorted(block_refs)} def getInRefs(self, func_addr): in_refs = [] diff --git a/smda/common/BinaryInfo.py b/smda/common/BinaryInfo.py index b65a72d..a1b0205 100644 --- a/smda/common/BinaryInfo.py +++ b/smda/common/BinaryInfo.py @@ -139,4 +139,6 @@ def getHeaderBytes(self): return self.raw_data[:0x400] elif isinstance(lief_result, lief.ELF.Binary): return self.raw_data[:0x40] + elif self.architecture == "dalvik" or self.raw_data[:4] == b"dex\n": + return self.raw_data[:0x70] return None diff --git a/smda/common/SmdaFunction.py b/smda/common/SmdaFunction.py index c113ae9..537fee7 100644 --- a/smda/common/SmdaFunction.py +++ b/smda/common/SmdaFunction.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import hashlib +import logging import re import struct from typing import Iterator @@ -11,6 +12,8 @@ from .SmdaInstruction import SmdaInstruction +LOGGER = logging.getLogger(__name__) + class SmdaFunction: smda_report = None @@ -24,6 +27,7 @@ class SmdaFunction: code_inrefs = None code_outrefs = None is_exported = None + architecture_metadata = None # metadata binweight = 0 characteristics = "" @@ -45,6 +49,8 @@ def __init__(self, disassembly=None, function_offset=None, config=None, smda_rep self.inrefs = disassembly.getInRefs(function_offset) self.outrefs = disassembly.getOutRefs(function_offset) self.is_exported = self.offset in disassembly.exported_functions + self.architecture_metadata = disassembly.function_metadata.get(function_offset, {}) + self.blockrefs = self.getNormalizedBlockRefs() # metadata self.function_name = disassembly.function_symbols.get(function_offset, "") self.characteristics = ( @@ -62,8 +68,22 @@ def __init__(self, disassembly=None, function_offset=None, config=None, smda_rep if function_offset in disassembly.candidates else None ) - if config and config.WITH_STRINGS: - self.stringrefs = 
disassembly.getStringRefsForFunction(function_offset) + # DEX strings are part of the parsed file structure, so they're always + # populated for Dalvik regardless of WITH_STRINGS — no extra extraction + # cost. For other architectures, honor WITH_STRINGS as usual. + if ( + config + and config.WITH_STRINGS + or ( + disassembly.binary_info.architecture == "dalvik" + and disassembly.getStringRefsForFunction(function_offset) + ) + ): + self.stringrefs = ( + self._normalizeDalvikStringRefs(disassembly.getStringRefsForFunction(function_offset)) + if disassembly.binary_info.architecture == "dalvik" + else disassembly.getStringRefsForFunction(function_offset) + ) if config and config.CALCULATE_HASHING: self.pic_hash = self.getPicHash(disassembly.binary_info) if config and config.CALCULATE_SCC: @@ -93,10 +113,16 @@ def num_instructions(self): @property def num_calls(self): + architecture = self.smda_report.architecture if self.smda_report else "" + if architecture == "dalvik": + return sum([1 for ins in self.getInstructions() if ins.mnemonic.startswith("invoke-")]) return sum([1 for ins in self.getInstructions() if ins.mnemonic == "call"]) @property def num_returns(self): + architecture = self.smda_report.architecture if self.smda_report else "" + if architecture == "dalvik": + return sum([1 for ins in self.getInstructions() if ins.mnemonic.startswith("return")]) return sum([1 for ins in self.getInstructions() if ins.mnemonic in ["ret", "retn"]]) def isApiThunk(self): @@ -142,17 +168,19 @@ def getCodeOutrefs(self): yield from self.code_outrefs def _calculateSccs(self): - tarjan = Tarjan(self.blockrefs) + tarjan = Tarjan(self.getNormalizedBlockRefs()) tarjan.calculateScc() return tarjan.getResult() def _calculateNestingDepth(self): nesting_depth = 0 try: - if self.blockrefs: - tree = build_dominator_tree(self.blockrefs, self.offset) + normalized_blockrefs = self.getNormalizedBlockRefs() + root = self._getCfgRoot(normalized_blockrefs) + if normalized_blockrefs and root is not 
None: + tree = build_dominator_tree(normalized_blockrefs, root) if tree: - nesting_depth = get_nesting_depth(self.blockrefs, tree, self.offset) + nesting_depth = get_nesting_depth(normalized_blockrefs, tree, root) except Exception: pass return nesting_depth @@ -191,6 +219,95 @@ def _parseBlocks(self, block_dict): self.blocks[int(offset)] = instructions self.binweight += sum([len(ins.bytes) / 2 for ins in instructions]) + @staticmethod + def _normalizeDalvikStringRefs(stringrefs): + if not stringrefs: + return [] + if isinstance(stringrefs, list): + normalized = [] + for entry in stringrefs: + if isinstance(entry, dict): + normalized.append( + { + "string": entry.get("string", ""), + "ins_addr": int(entry.get("ins_addr", 0)), + "data_addr": entry.get("data_addr", None), + "type": entry.get("type", "dex"), + } + ) + return normalized + if isinstance(stringrefs, dict): + return [ + { + "string": string_value, + "ins_addr": int(referencing_addr), + "data_addr": None, + "type": "dex", + } + for referencing_addr, string_value in sorted(stringrefs.items()) + ] + return stringrefs + + def _getContainingBlockStart(self, instruction_addr): + for block_start, block in self.blocks.items(): + if not block: + continue + block_end = block[-1].offset + (len(block[-1].bytes) // 2) + if block_start <= instruction_addr < block_end: + return block_start + return None + + def _getCfgRoot(self, normalized_blockrefs): + if self.offset in normalized_blockrefs: + return self.offset + block_start = self._getContainingBlockStart(self.offset) + if block_start is not None: + return block_start + if normalized_blockrefs: + # No entry block found for self.offset — refuse to fabricate a root, + # since dominator/nesting derived from a wrong root is silently misleading. 
+ LOGGER.warning( + "Normalized CFG for %s (0x%x) has no entry block; skipping root-dependent analysis.", + self.function_name or "", + self.offset, + ) + return None + LOGGER.warning("Normalized CFG for %s (0x%x) is empty.", self.function_name or "", self.offset) + return None + + def getNormalizedBlockRefs(self): + current_blockrefs = self.blockrefs or {} + normalized_blockrefs = { + block_start: sorted(current_blockrefs.get(block_start, [])) for block_start in self.blocks + } + try_ranges = self.architecture_metadata.get("try_ranges", []) if self.architecture_metadata else [] + for try_range in try_ranges: + raw_targets = [] + for handler in try_range.get("handlers", []): + target_addr = handler.get("target_addr") if isinstance(handler, dict) else None + if target_addr is not None: + raw_targets.append(target_addr) + if try_range.get("catch_all_addr") is not None: + raw_targets.append(try_range["catch_all_addr"]) + if not raw_targets: + continue + normalized_targets = set() + for target_addr in raw_targets: + block_start = self._getContainingBlockStart(target_addr) + if block_start is None: + block_start = target_addr + normalized_targets.add(block_start) + normalized_blockrefs.setdefault(block_start, []) + for block_start, block in self.blocks.items(): + if not block: + continue + block_end = block[-1].offset + (len(block[-1].bytes) // 2) + if try_range["start_addr"] < block_end and block_start < try_range["end_addr"]: + successors = set(normalized_blockrefs.get(block_start, [])) + successors.update(normalized_targets) + normalized_blockrefs[block_start] = sorted(successors) + return {block_start: normalized_blockrefs[block_start] for block_start in sorted(normalized_blockrefs)} + def toDotGraph(self, with_api=False): dot_graph = f'digraph "CFG for 0x{self.offset:x}" {{\n' dot_graph += f' label="CFG for 0x{self.offset:x}";\n' @@ -229,6 +346,8 @@ def fromDict(cls, function_dict, binary_info=None, version=None, smda_report=Non smda_function.outrefs = {int(k): v 
for k, v in function_dict["outrefs"].items()} # provide some legacy support by assuming functions are not exported for SMDA reports < 1.7.0 smda_function.is_exported = function_dict.get("is_exported", False) + smda_function.architecture_metadata = function_dict.get("architecture_metadata", {}) + smda_function.blockrefs = smda_function.getNormalizedBlockRefs() smda_function.binweight = function_dict["metadata"]["binweight"] smda_function.characteristics = function_dict["metadata"]["characteristics"] smda_function.confidence = function_dict["metadata"]["confidence"] @@ -236,7 +355,16 @@ def fromDict(cls, function_dict, binary_info=None, version=None, smda_report=Non smda_function.pic_hash = function_dict["metadata"].get("pic_hash", None) smda_function.strongly_connected_components = function_dict["metadata"]["strongly_connected_components"] smda_function.tfidf = function_dict["metadata"]["tfidf"] - smda_function.stringrefs = function_dict.get("stringrefs", {}) + stringrefs = function_dict.get("stringrefs", {}) + function_architecture = None + if smda_report is not None: + function_architecture = smda_report.architecture + elif binary_info is not None: + function_architecture = binary_info.architecture + if function_architecture == "dalvik": + smda_function.stringrefs = smda_function._normalizeDalvikStringRefs(stringrefs) + else: + smda_function.stringrefs = stringrefs if binary_info and binary_info.architecture: smda_function._escaper = IntelInstructionEscaper if binary_info.architecture in ["intel"] else None else: @@ -278,6 +406,7 @@ def toDict(self) -> dict: "inrefs": self.inrefs, "outrefs": self.outrefs, "is_exported": self.is_exported, + "architecture_metadata": self.architecture_metadata if self.architecture_metadata is not None else {}, "metadata": { "binweight": self.binweight, "characteristics": self.characteristics, diff --git a/smda/common/SmdaInstruction.py b/smda/common/SmdaInstruction.py index 257f8bb..6bb0809 100644 --- a/smda/common/SmdaInstruction.py 
+++ b/smda/common/SmdaInstruction.py @@ -25,6 +25,14 @@ def __init__(self, ins_list=None, smda_function=None): self.operands = ins_list[3] def getDataRefs(self): + emitted = set() + smda_report = self.smda_function.smda_report + if smda_report.data_refs_from is not None and self.offset in smda_report.data_refs_from: + for value in sorted(smda_report.data_refs_from[self.offset]): + emitted.add(value) + yield value + if smda_report.architecture != "intel": + return if self.getMnemonicGroup(IntelInstructionEscaper) != "C": detailed = self.getDetailed() if len(detailed.operands) > 0: @@ -37,10 +45,18 @@ def getDataRefs(self): if detailed.reg_name(i.mem.base) == "rip": # add RIP value value += detailed.address + detailed.size - if value is not None and self.smda_function.smda_report.isAddrWithinMemoryImage(value): + if ( + value is not None + and value not in emitted + and self.smda_function.smda_report.isAddrWithinMemoryImage(value) + ): + emitted.add(value) yield value def getDetailed(self): + arch = self.smda_function.smda_report.architecture + if arch is not None and arch != "intel": + raise NotImplementedError(f"getDetailed() is only available for Intel architecture, not '{arch}'") if self.detailed is None: capstone = self.smda_function.smda_report.getCapstone() with_details = list(capstone.disasm(bytes.fromhex(self.bytes), self.offset)) diff --git a/smda/common/SmdaReport.py b/smda/common/SmdaReport.py index c41246b..07a3367 100644 --- a/smda/common/SmdaReport.py +++ b/smda/common/SmdaReport.py @@ -50,6 +50,8 @@ class SmdaReport: version = None xcfg = None xheader = None + data_refs_from = None + data_refs_to = None # on first usage, initialize codexrefs objects for all functions based on inrefs/outrefs (requires knowledge about all functions) _has_codexrefs = False @@ -92,6 +94,8 @@ def __init__(self, disassembly=None, config=None, buffer=None): self.version = disassembly.binary_info.version self.xcfg = self._convertCfg(disassembly, config=config) self.xheader = 
disassembly.binary_info.getHeaderBytes() + self.data_refs_from = {src: sorted(dst) for src, dst in disassembly.data_refs_from.items()} + self.data_refs_to = {dst: sorted(src) for dst, src in disassembly.data_refs_to.items()} self.xmetadata = { "exported_functions": disassembly.binary_info.getExportedFunctions(), "imported_functions": disassembly.binary_info.getImportedFunctions(), @@ -262,6 +266,8 @@ def fromDict(cls, report_dict) -> Optional["SmdaReport"]: smda_report.statistics = DisassemblyStatistics.fromDict(report_dict["statistics"]) smda_report.status = report_dict["status"] smda_report.timestamp = datetime.datetime.strptime(report_dict["timestamp"], "%Y-%m-%dT%H-%M-%S") + smda_report.data_refs_from = {int(k): v for k, v in report_dict.get("data_refs_from", {}).items()} + smda_report.data_refs_to = {int(k): v for k, v in report_dict.get("data_refs_to", {}).items()} binary_info = BinaryInfo(b"") binary_info.architecture = smda_report.architecture binary_info.abi = smda_report.abi @@ -321,6 +327,8 @@ def toDict(self) -> dict: "timestamp": self.timestamp.strftime("%Y-%m-%dT%H-%M-%S"), "xcfg": {function_addr: smda_function.toDict() for function_addr, smda_function in self.xcfg.items()}, "xheader": self.xheader.hex() if self.xheader else "", + "data_refs_from": self.data_refs_from if self.data_refs_from is not None else {}, + "data_refs_to": self.data_refs_to if self.data_refs_to is not None else {}, "xmetadata": self.xmetadata, } @@ -332,4 +340,5 @@ def toFile(self, output_filepath) -> None: def __str__(self): if self.status == "error": return f"{self.execution_time:>6.3f}s -> {self.message}" - return f"{self.execution_time:>6.3f}s -> (architecture: {self.architecture}.{self.bitness}bit, base_addr: 0x{self.base_addr:08x}): {len(self.xcfg)} functions" + arch_str = f"{self.architecture}.{self.bitness}bit" if self.bitness else self.architecture + return f"{self.execution_time:>6.3f}s -> (architecture: {arch_str}, base_addr: 0x{self.base_addr:08x}): {len(self.xcfg)} 
functions" diff --git a/smda/dalvik/DalvikDisassembler.py b/smda/dalvik/DalvikDisassembler.py new file mode 100644 index 0000000..057f70d --- /dev/null +++ b/smda/dalvik/DalvikDisassembler.py @@ -0,0 +1,973 @@ +import contextlib +import datetime +import logging +import struct + +import lief + +from smda.dalvik.DalvikFunctionAnalysisState import DalvikFunctionAnalysisState +from smda.dalvik.DalvikOpcodeDecoder import ( + decode_instruction, + parse_code_item_header, + read_sleb128, + read_uleb128, +) +from smda.DisassemblyResult import DisassemblyResult +from smda.utility.DexFileLoader import DexFileLoader + +LOGGER = logging.getLogger(__name__) + + +class DexReferenceResolver: + PRIMITIVE_TYPES = { + "VOID_T": "V", + "VOID": "V", + "BOOLEAN": "Z", + "BYTE": "B", + "SHORT": "S", + "CHAR": "C", + "INT": "I", + "LONG": "J", + "FLOAT": "F", + "DOUBLE": "D", + } + STRING_PRIMITIVE_TYPES = { + "void": "V", + "boolean": "Z", + "byte": "B", + "short": "S", + "char": "C", + "int": "I", + "long": "J", + "float": "F", + "double": "D", + } + ACCESS_FLAG_NAMES = [ + (0x0001, "public"), + (0x0002, "private"), + (0x0004, "protected"), + (0x0008, "static"), + (0x0010, "final"), + (0x0020, "synchronized"), + (0x0040, "bridge"), + (0x0080, "varargs"), + (0x0100, "native"), + (0x0200, "interface"), + (0x0400, "abstract"), + (0x0800, "strict"), + (0x1000, "synthetic"), + (0x2000, "annotation"), + (0x4000, "enum"), + (0x8000, "unused"), + (0x10000, "constructor"), + (0x20000, "declared-synchronized"), + ] + + def __init__(self, dex_file): + self.dex_file = dex_file + self.strings = list(getattr(dex_file, "strings", [])) + self.methods = self._indexItems(getattr(dex_file, "methods", [])) + self.fields = self._indexItems(getattr(dex_file, "fields", [])) + self.types = self._indexItems(getattr(dex_file, "types", [])) + self.prototypes = self._indexItems(getattr(dex_file, "prototypes", [])) + self.classes = self._indexItems(getattr(dex_file, "classes", [])) + + def _indexItems(self, 
items): + indexed = {} + for index, item in enumerate(items): + indexed[getattr(item, "index", index)] = item + return indexed + + def _safeGet(self, collection, index): + if index in collection: + return collection[index] + return None + + def _safeAttr(self, obj, attr, default=None): + try: + return getattr(obj, attr) + except Exception: + return default + + def _normalizeTypeString(self, type_name): + if not type_name: + return None + if type_name in {"V", "Z", "B", "S", "C", "I", "J", "F", "D"}: + return type_name + if type_name.startswith("[") or (type_name.startswith("L") and type_name.endswith(";")): + return type_name + lowered = type_name.lower() + if lowered in self.STRING_PRIMITIVE_TYPES: + return self.STRING_PRIMITIVE_TYPES[lowered] + if "." in type_name and "/" not in type_name: + return f"L{type_name.replace('.', '/')};" + return type_name + + def _formatType(self, type_obj): + if type_obj is None: + return "" + if isinstance(type_obj, str): + return self._normalizeTypeString(type_obj) + fullname = self._safeAttr(type_obj, "fullname", None) + if fullname: + return self._normalizeTypeString(fullname) + value = self._safeAttr(type_obj, "value", None) + if value is not None: + fullname = self._safeAttr(value, "fullname", None) + if fullname: + return self._normalizeTypeString(fullname) + primitive_name = self.PRIMITIVE_TYPES.get(self._safeAttr(value, "name", ""), None) + if primitive_name: + return primitive_name + name = self._safeAttr(value, "name", None) + if name: + return self._normalizeTypeString(name) + with contextlib.suppress(Exception): + normalized = self._normalizeTypeString(str(value)) + if normalized and not normalized.startswith("<"): + return normalized + return "<unknown>" + + def _formatProto(self, prototype): + if prototype is None: + return "()" + params = [] + for param in getattr(prototype, "parameters_type", []): + params.append(self._formatType(param)) + return_type = self._formatType(getattr(prototype, "return_type", None)) + return f"({''.join(params)}){return_type}" + + 
def formatMethod(self, method): + if method is None: + return "method@" + class_name = self._formatType(getattr(method, "cls", None)) + method_name = getattr(method, "name", "") + prototype = self._formatProto(getattr(method, "prototype", None)) + return f"{class_name}->{method_name}{prototype}" + + def formatField(self, field): + if field is None: + return "field@" + class_name = self._formatType(getattr(field, "cls", None)) + field_name = getattr(field, "name", "") + field_type = self._formatType(getattr(field, "type", None)) + return f"{class_name}->{field_name}:{field_type}" + + def formatProto(self, index): + prototype = self._safeGet(self.prototypes, index) + if prototype is None: + return f"proto@{index}" + return self._formatProto(prototype) + + def formatTypeByIndex(self, index): + type_obj = self._safeGet(self.types, index) + if type_obj is None: + return f"type@{index}" + return self._formatType(type_obj) + + # Baksmali-style escape map for DEX string literals. + _STRING_ESCAPE_MAP = { + '"': '\\"', + "\\": "\\\\", + "\n": "\\n", + "\r": "\\r", + "\t": "\\t", + "\0": "\\0", + } + + @classmethod + def _escapeDexString(cls, s): + parts = [] + for ch in s: + escaped = cls._STRING_ESCAPE_MAP.get(ch) + if escaped: + parts.append(escaped) + elif ch.isprintable(): + parts.append(ch) + else: + parts.append(f"\\u{ord(ch):04x}") + return "".join(parts) + + def formatRef(self, ref_kind, index): + if ref_kind == "string": + if 0 <= index < len(self.strings): + return '"' + self._escapeDexString(self.strings[index]) + '"' + return f"string@{index}" + if ref_kind == "type": + return self.formatTypeByIndex(index) + if ref_kind == "field": + return self.formatField(self._safeGet(self.fields, index)) + if ref_kind == "method": + return self.formatMethod(self._safeGet(self.methods, index)) + if ref_kind == "proto": + return self.formatProto(index) + if ref_kind == "method_handle": + return f"method_handle@{index}" + if ref_kind == "call_site": + return 
f"call_site@{index}" + return f"{ref_kind}@{index}" if ref_kind else f"item@{index}" + + def getMethod(self, method_index): + return self._safeGet(self.methods, method_index) + + def getMethodTarget(self, method_index): + method = self.getMethod(method_index) + if method is None: + return None, None + code_offset = getattr(method, "code_offset", 0) + code_info = getattr(method, "code_info", None) + if code_offset and code_info: + return code_offset, self.formatMethod(method) + return None, self.formatMethod(method) + + def getStringValue(self, string_index): + if 0 <= string_index < len(self.strings): + return self.strings[string_index] + return None + + def getMethodMetadata(self, method): + access_flags = getattr(method, "access_flags", 0) + access_flags = getattr(access_flags, "value", access_flags) + if isinstance(access_flags, (list, tuple, set)): + normalized_flags = 0 + for flag in access_flags: + normalized_flags |= getattr(flag, "value", 0) + access_flags = normalized_flags + access_flag_names = [name for mask, name in self.ACCESS_FLAG_NAMES if access_flags & mask] + method_name = self.formatMethod(method) + return { + "method_name": method_name, + "class_name": self._formatType(getattr(method, "cls", None)), + "prototype": self._formatProto(getattr(method, "prototype", None)), + "access_flags": access_flags, + "access_flags_decoded": access_flag_names, + } + + +class DalvikDisassembler: + MAX_SWITCH_TARGETS_FOR_HEURISTIC = 32 + + def __init__(self, config): + self.config = config + self.disassembly = DisassemblyResult() + self.disassembly.smda_version = config.VERSION + self._diag_stubs_suppressed = 0 + + def addPdbFile(self, binary_info, pdb_path): + return + + def _formatReferenceCounts(self, reference_counts): + active_counts = [f"{ref_kind}={count}" for ref_kind, count in reference_counts.items() if count] + return ", ".join(active_counts) if active_counts else "none" + + def _summarizeHeuristics(self): + heuristic_counts = {} + for metadata in 
self.disassembly.function_metadata.values(): + for heuristic in metadata.get("heuristics", []): + heuristic_counts[heuristic] = heuristic_counts.get(heuristic, 0) + 1 + if not heuristic_counts: + return "none" + return ", ".join(f"{name}={heuristic_counts[name]}" for name in sorted(heuristic_counts)) + + def _logMethodDiagnostics(self, state): + metadata = state.metadata + heuristics = metadata["heuristics"] + n_ins = len(state.instructions) + n_blocks = len(self.disassembly.functions.get(state.start_addr, [])) + n_handlers = metadata.get("exception_handler_count", 0) + + # Suppress trivial single-block stubs with no heuristics — typically generated + # string-obfuscation shims. Count them for the summary instead. + if n_ins <= 8 and n_blocks <= 1 and n_handlers == 0 and not heuristics: + self._diag_stubs_suppressed += 1 + return + + heuristic_str = ", ".join(heuristics) if heuristics else "none" + outref_count = len(self.disassembly.getOutRefs(state.start_addr)) + string_ref_count = len(self.disassembly.getStringRefsForFunction(state.start_addr)) + fmt = "0x%08x %s: ins=%d blocks=%d outrefs=%d handlers=%d strings=%d refs=[%s] heuristics=[%s]" + args = ( + state.start_addr, + state.label, + n_ins, + n_blocks, + outref_count, + n_handlers, + string_ref_count, + self._formatReferenceCounts(metadata["reference_counts"]), + heuristic_str, + ) + # Both tiers logged at DEBUG — per-method detail belongs in verbose mode only. + # Heuristic summary is surfaced via the INFO-level analysis banner instead. + prefix = "[!] 
" if heuristics else "" + LOGGER.debug(prefix + fmt, *args) + + def _logAnalysisSummary(self, version, method_counts, analyzed_count): + total_blocks = sum(len(blocks) for blocks in self.disassembly.functions.values()) + total_strings = sum(len(stringrefs) for stringrefs in self.disassembly.stringrefs.values()) + try_functions = sum( + 1 for metadata in self.disassembly.function_metadata.values() if metadata.get("exception_handler_count", 0) + ) + heuristics_str = self._summarizeHeuristics() + sep = "-" * 68 + LOGGER.info(sep) + LOGGER.info( + "DEX v%s | analyzed %d/%d | functions=%d blocks=%d instructions=%d", + version, + analyzed_count, + method_counts["total"], + len(self.disassembly.functions), + total_blocks, + len(self.disassembly.instructions), + ) + LOGGER.info( + " api_refs=%d string_refs=%d try_blocks=%d failed=%d", + len(self.disassembly.addr_to_api), + total_strings, + try_functions, + len(self.disassembly.failed_analysis_addr), + ) + skip_nc = method_counts["skipped_no_class"] + skip_nc2 = method_counts["skipped_no_code"] + skip_inv = method_counts["skipped_invalid_offset"] + if skip_nc + skip_nc2 + skip_inv: + LOGGER.info( + " skipped: no_class=%d no_code=%d invalid_offset=%d", + skip_nc, + skip_nc2, + skip_inv, + ) + if self._diag_stubs_suppressed: + LOGGER.info(" (%d trivial stubs suppressed from per-method log)", self._diag_stubs_suppressed) + if heuristics_str: + LOGGER.info("[!] 
heuristics: %s", heuristics_str) + LOGGER.info(sep) + + def _getPayloadSize(self, bytecode, idx): + if idx < 0 or idx + 2 > len(bytecode): + return 0 + max_possible = len(bytecode) - idx + ident = struct.unpack_from(" len(bytecode): + return 0 + size = struct.unpack_from(" len(bytecode): + return 0 + size = struct.unpack_from(" len(bytecode): + return 0 + element_width = struct.unpack_from(" len(bytecode): + return [] + ident = struct.unpack_from(" len(bytecode): + break + rel_offset = struct.unpack_from(" len(bytecode): + break + rel_offset = struct.unpack_from(" len(raw_data): + raise ValueError("Invalid try_item table") + handlers_offset = tries_offset + tries_size * 8 + encoded_handler_count, cursor = read_uleb128(raw_data, handlers_offset) + handlers_by_offset = {} + structural_violations = [] + for _ in range(encoded_handler_count): + handler_relative = cursor - handlers_offset + encoded_size, cursor = read_sleb128(raw_data, cursor) + catch_all_addr = None + handlers = [] + for _ in range(abs(encoded_size)): + type_idx, cursor = read_uleb128(raw_data, cursor) + addr, cursor = read_uleb128(raw_data, cursor) + handlers.append({"type_idx": type_idx, "addr_units": addr}) + if encoded_size <= 0: + catch_all_addr, cursor = read_uleb128(raw_data, cursor) + handlers_by_offset[handler_relative] = { + "handlers": handlers, + "catch_all_addr": catch_all_addr, + } + + try_items = [] + for index in range(tries_size): + start_addr_units, insn_count_units, handler_off = struct.unpack_from( + "", + "target_addr": catch_all_addr, + "protected_range_start": try_block["start_addr"], + "protected_range_end": try_block["end_addr"], + } + ) + try_ranges.append( + { + "start_addr": try_block["start_addr"], + "end_addr": try_block["end_addr"], + "handlers": handlers, + "handler_targets": handler_targets, + "catch_all_addr": catch_all_addr, + } + ) + metadata.update( + { + "registers_size": code_item_header["registers_size"], + "ins_size": code_item_header["ins_size"], + "outs_size": 
code_item_header["outs_size"], + "tries_size": code_item_header["tries_size"], + "debug_info_off": code_item_header["debug_info_off"], + "insns_size_units": code_item_header["insns_size"], + "exception_handler_count": len(exception_handlers), + "exception_handlers": exception_handlers, + "try_ranges": try_ranges, + "heuristics": [], + "reference_counts": { + "string": 0, + "type": 0, + "field": 0, + "method": 0, + "proto": 0, + "call_site": 0, + "method_handle": 0, + }, + "structural_violations": [], + } + ) + return metadata + + def _updateHeuristics(self, metadata, decoded, payload_size): + heuristics = metadata["heuristics"] + if metadata["tries_size"] > 0 and "exception-obfuscation-surface" not in heuristics: + heuristics.append("exception-obfuscation-surface") + if ( + decoded.mnemonic + in { + "invoke-custom", + "invoke-custom/range", + "invoke-polymorphic", + "invoke-polymorphic/range", + } + and "advanced-dispatch" not in heuristics + ): + heuristics.append("advanced-dispatch") + if ( + decoded.payload_kind in {"packed-switch", "sparse-switch"} + and payload_size > self.MAX_SWITCH_TARGETS_FOR_HEURISTIC * 4 + and "large-switch-payload" not in heuristics + ): + heuristics.append("large-switch-payload") + if decoded.ref_kind == "method": + operand = decoded.operands + if ( + any(indicator in operand for indicator in ("Ljava/lang/reflect/", "Ljava/lang/Class;->forName")) + and "reflection-hotspot" not in heuristics + ): + heuristics.append("reflection-hotspot") + if ( + any( + indicator in operand + for indicator in ("loadLibrary", "load(", "DexClassLoader", "InMemoryDexClassLoader") + ) + and "native-or-dynamic-loading" not in heuristics + ): + heuristics.append("native-or-dynamic-loading") + if ( + decoded.mnemonic == "const-string" + and metadata["reference_counts"]["string"] >= 3 + and "string-staging" not in heuristics + ): + heuristics.append("string-staging") + + def _buildValidInstructionStarts(self, bytecode): + """Linear sweep (pass 1) that records 
all legal instruction-start byte offsets. + + This is intentionally cheap: no CFG, no operand resolution, no side effects. + The result is used in the recursive CFG pass (pass 2) to reject externally- + derived targets (switch tables, exception handlers) that land mid-instruction + or outside the method, preventing phantom CFG nodes from adversarial DEX. + """ + valid = set() + payload_ranges = [] + idx = 0 + null_resolve = lambda ref_kind, ref_index: "" # noqa: E731 + # Dalvik instructions are 16-bit aligned (one "code unit"), so always + # advance by 2 on resync — stepping by 1 wastes a decode attempt at every + # odd offset on adversarial input. + while idx < len(bytecode): + if any(start <= idx < end for start, end in payload_ranges): + idx += 2 + continue + try: + decoded = decode_instruction(bytecode, idx, null_resolve) + except ValueError: + idx += 2 + continue + valid.add(idx) + if decoded.payload_idx is not None: + payload_size = self._getPayloadSize(bytecode, decoded.payload_idx) + if payload_size: + payload_ranges.append((decoded.payload_idx, decoded.payload_idx + payload_size)) + idx += decoded.size_bytes + return valid + + def _addStructuralViolation(self, metadata, violation_type, **fields): + violation = {"type": violation_type} + violation.update(fields) + metadata.setdefault("structural_violations", []).append(violation) + + def _sanitizeTryBlocks(self, try_blocks, bytecode_offset, valid_instruction_starts): + sanitized = [] + structural_violations = [] + for try_block in try_blocks: + sanitized_handlers = [] + for handler in try_block["handlers"]: + target_addr = handler["target_addr"] + target_idx = target_addr - bytecode_offset + if target_idx in valid_instruction_starts: + sanitized_handlers.append(handler) + else: + structural_violations.append( + { + "type": "invalid_handler_target", + "protected_range_start": hex(try_block["start_addr"]), + "target": hex(target_addr), + } + ) + catch_all_addr = try_block["catch_all_addr"] + if catch_all_addr 
is not None and (catch_all_addr - bytecode_offset) not in valid_instruction_starts: + structural_violations.append( + { + "type": "invalid_handler_target", + "protected_range_start": hex(try_block["start_addr"]), + "target": hex(catch_all_addr), + } + ) + catch_all_addr = None + sanitized.append( + { + "start_addr": try_block["start_addr"], + "end_addr": try_block["end_addr"], + "handlers": sanitized_handlers, + "catch_all_addr": catch_all_addr, + } + ) + return sanitized, structural_violations + + def _validateBranchTarget(self, metadata, source_addr, source_idx, target_idx, valid_instruction_starts): + target_addr = source_addr + (target_idx - source_idx) + if target_idx == source_idx: + self._addStructuralViolation( + metadata, + "zero_branch_offset", + from_addr=hex(source_addr), + target=hex(target_addr), + ) + return None + if target_idx not in valid_instruction_starts: + self._addStructuralViolation( + metadata, + "invalid_branch_target", + from_addr=hex(source_addr), + target=hex(target_addr), + ) + return None + return target_addr + + def analyzeFunction(self, dex_file, resolver, method): + start_addr = getattr(method, "code_offset", 0) + raw_data = self.disassembly.binary_info.raw_data + bytecode_offset = start_addr + header_offset = start_addr - 16 + code_item_header = parse_code_item_header(raw_data, header_offset) + insns_size_bytes = code_item_header["insns_size"] * 2 + if bytecode_offset + insns_size_bytes > len(raw_data): + raise ValueError("Invalid Dalvik bytecode range") + + bytecode = raw_data[bytecode_offset : bytecode_offset + insns_size_bytes] + try_blocks, try_violations = self._parseTryBlocks( + raw_data, + resolver, + bytecode_offset, + code_item_header["insns_size"], + code_item_header["tries_size"], + ) + + # Pass 1: build a set of legal instruction-start byte offsets for target validation. 
+ valid_instruction_starts = self._buildValidInstructionStarts(bytecode) + + try_blocks, target_violations = self._sanitizeTryBlocks(try_blocks, bytecode_offset, valid_instruction_starts) + + state = DalvikFunctionAnalysisState(bytecode_offset, self.disassembly) + metadata = self._buildFunctionMetadata(resolver, method, code_item_header, try_blocks) + state.metadata = metadata + metadata["structural_violations"].extend(try_violations) + metadata["structural_violations"].extend(target_violations) + + # Queue exception-handler entry points, validating against instruction boundaries. + for try_block in try_blocks: + for handler in try_block["handlers"]: + target_addr = handler["target_addr"] + state.addBlockStart(target_addr) + state.addBlockToQueue(target_addr) + if try_block["catch_all_addr"] is not None: + ca_addr = try_block["catch_all_addr"] + state.addBlockStart(ca_addr) + state.addBlockToQueue(ca_addr) + + visited_offsets = set() + payload_ranges = [] + + while state.hasUnprocessedBlocks(): + block_start_addr = state.chooseNextBlock() + idx = block_start_addr - bytecode_offset + while 0 <= idx < len(bytecode): + if any(start <= idx < end for start, end in payload_ranges): + break + if idx in visited_offsets: + break + visited_offsets.add(idx) + + try: + decoded = decode_instruction( + bytecode, + idx, + lambda ref_kind, ref_index: self._resolveReference(resolver, ref_kind, ref_index), + ) + except ValueError as exc: + self.disassembly.errors[bytecode_offset + idx] = { + "type": "dalvik_decode_error", + "instruction_bytes": bytecode[idx : idx + 2].hex(), + "message": str(exc), + } + LOGGER.warning("Failed to decode Dalvik instruction at 0x%x: %s", bytecode_offset + idx, exc) + state.decode_error_count += 1 + state.is_partial = True + break + + i_address = bytecode_offset + idx + i_size = decoded.size_bytes + i_mnemonic = decoded.mnemonic + i_op_str = decoded.operands + + if decoded.ref_kind in metadata["reference_counts"]: + 
metadata["reference_counts"][decoded.ref_kind] += 1 + + state.setNextInstructionReachable(not decoded.is_terminator) + + if decoded.ref_kind == "string" and decoded.ref_index is not None: + string_value = resolver.getStringValue(decoded.ref_index) + if string_value is not None: + self.disassembly.addStringRef(state.start_addr, i_address, string_value) + if decoded.payload_idx is not None: + payload_size = self._getPayloadSize(bytecode, decoded.payload_idx) + if payload_size: + payload_ranges.append((decoded.payload_idx, decoded.payload_idx + payload_size)) + payload_addr = bytecode_offset + decoded.payload_idx + state.addDataRef(i_address, payload_addr, size=payload_size) + if decoded.payload_kind in ("packed-switch", "sparse-switch"): + switch_targets = self._resolveSwitchTargets(bytecode, idx, decoded.payload_idx) + for target_idx in switch_targets: + target_addr = bytecode_offset + target_idx + if target_idx not in valid_instruction_starts: + self._addStructuralViolation( + metadata, + "invalid_switch_target", + from_addr=hex(i_address), + target=hex(target_addr), + ) + continue + state.addCodeRef(i_address, target_addr, by_jump=True) + state.addBlockStart(target_addr) + state.addBlockToQueue(target_addr) + fallthrough = i_address + i_size + state.addBlockStart(fallthrough) + state.addBlockToQueue(fallthrough) + self._updateHeuristics(metadata, decoded, payload_size) + else: + self._updateHeuristics(metadata, decoded, 0) + + if i_mnemonic.startswith("goto"): + target_addr = self._validateBranchTarget( + metadata, i_address, idx, decoded.branch_target_idx, valid_instruction_starts + ) + if target_addr is not None: + state.addCodeRef(i_address, target_addr, by_jump=True) + state.addBlockStart(target_addr) + state.addBlockToQueue(target_addr) + elif decoded.is_conditional and decoded.branch_target_idx is not None: + target_addr = self._validateBranchTarget( + metadata, i_address, idx, decoded.branch_target_idx, valid_instruction_starts + ) + if target_addr is not 
None: + state.addCodeRef(i_address, target_addr, by_jump=True) + state.addBlockStart(target_addr) + state.addBlockToQueue(target_addr) + fallthrough = i_address + i_size + state.addBlockStart(fallthrough) + state.addBlockToQueue(fallthrough) + elif decoded.is_invoke: + state.setLeaf(False) + call_target = None + call_name = None + if decoded.ref_kind == "method" and decoded.ref_index is not None: + call_target, call_name = resolver.getMethodTarget(decoded.ref_index) + elif decoded.ref_index is not None: + call_name = resolver.formatRef(decoded.ref_kind, decoded.ref_index) + if call_target is not None: + state.addCodeRef(i_address, call_target) + if call_target == state.start_addr: + state.setRecursion(True) + elif call_name: + self._updateApiInformation(i_address, call_name) + + if self._instructionCanThrow(decoded) and try_blocks: + has_exception_edges = self._applyExceptionEdges(state, i_address, decoded, try_blocks) + if has_exception_edges and state.is_next_instruction_reachable: + fallthrough = i_address + i_size + fallthrough_idx = idx + i_size + if fallthrough_idx in valid_instruction_starts: + state.addBlockStart(fallthrough) + state.addBlockToQueue(fallthrough) + + state.addInstruction(i_address, i_size, i_mnemonic, i_op_str, decoded.bytes_) + idx += i_size + if not state.is_next_instruction_reachable: + break + state.endBlock() + + state.label = resolver.formatMethod(method) + state.finalizeAnalysis() + self._logMethodDiagnostics(state) + return state + + def analyzeBuffer(self, binary_info, cbAnalysisTimeout=None): + self._diag_stubs_suppressed = 0 + LOGGER.info("Analyzing buffer with %d bytes @0x%08x", binary_info.binary_size, binary_info.base_addr) + self.disassembly = DisassemblyResult() + self.disassembly.smda_version = self.config.VERSION + self.disassembly.setBinaryInfo(binary_info) + self.disassembly.binary_info.architecture = "dalvik" + self.disassembly.binary_info.bitness = 32 # Dalvik VM is always 32-bit + self.disassembly.binary_info.version = 
"" + self.disassembly.analysis_start_ts = datetime.datetime.now(datetime.timezone.utc) + self.disassembly.language = "dalvik" + + if not DexFileLoader.isCompatible(binary_info.raw_data): + raise ValueError("Buffer is not a valid DEX file") + + dex_file = None + if getattr(binary_info, "file_path", "") and not getattr(binary_info, "is_buffer", False): + with contextlib.suppress(Exception): + dex_file = lief.DEX.parse(binary_info.file_path) + if dex_file is None: + # Prefer bytes/memoryview: list(raw_data) allocates a PyLong per byte + # (~30x memory blowup on large DEX). Older LIEF builds only accept + # List[int] and either raise TypeError or return None — fall back then. + raw_data = binary_info.raw_data + if not isinstance(raw_data, (bytes, bytearray)): + raw_data = bytes(raw_data) + try: + dex_file = lief.DEX.parse(raw_data) + except TypeError: + dex_file = None + if dex_file is None: + dex_file = lief.DEX.parse(list(raw_data)) + if dex_file is None: + raise ValueError("Failed to parse DEX file") + + self.disassembly.binary_info.version = getattr(dex_file, "version", "") + resolver = DexReferenceResolver(dex_file) + methods = list(dex_file.methods) + method_counts = { + "total": len(methods), + "skipped_no_class": 0, + "skipped_no_code": 0, + "skipped_invalid_offset": 0, + } + sep = "-" * 68 + LOGGER.info(sep) + LOGGER.info( + "DEX v%s | classes=%d methods=%d strings=%d types=%d fields=%d protos=%d", + self.disassembly.binary_info.version, + len(list(getattr(dex_file, "classes", []))), + len(methods), + len(resolver.strings), + len(resolver.types), + len(resolver.fields), + len(resolver.prototypes), + ) + LOGGER.info(sep) + + analyzed_count = 0 + for method in methods: + if cbAnalysisTimeout and cbAnalysisTimeout(): + break + if not getattr(method, "has_class", False): + method_counts["skipped_no_class"] += 1 + continue + if not getattr(method, "code_info", None): + method_counts["skipped_no_code"] += 1 + continue + if getattr(method, "code_offset", 0) < 16: + 
method_counts["skipped_invalid_offset"] += 1 + continue + try: + self.analyzeFunction(dex_file, resolver, method) + analyzed_count += 1 + except Exception as exc: + LOGGER.warning( + "Failed to analyze Dalvik method %s @0x%x: %s", + resolver.formatMethod(method), + getattr(method, "code_offset", 0), + exc, + ) + method_offset = getattr(method, "code_offset", 0) + self.disassembly.failed_analysis_addr.append(method_offset) + self.disassembly.errors[method_offset] = { + "type": "dalvik_function_error", + "instruction_bytes": "", + "message": str(exc), + } + + self.disassembly.analysis_end_ts = datetime.datetime.now(datetime.timezone.utc) + if cbAnalysisTimeout and cbAnalysisTimeout(): + self.disassembly.analysis_timeout = True + self._logAnalysisSummary(self.disassembly.binary_info.version, method_counts, analyzed_count) + return self.disassembly diff --git a/smda/dalvik/DalvikFunctionAnalysisState.py b/smda/dalvik/DalvikFunctionAnalysisState.py new file mode 100644 index 0000000..7218115 --- /dev/null +++ b/smda/dalvik/DalvikFunctionAnalysisState.py @@ -0,0 +1,284 @@ +import logging + +LOGGER = logging.getLogger(__name__) + +# Dalvik-specific lists for block derivation +# Any invoke-* is a call +CALL_INS = [ + "invoke-virtual", + "invoke-super", + "invoke-direct", + "invoke-static", + "invoke-interface", + "invoke-virtual/range", + "invoke-super/range", + "invoke-direct/range", + "invoke-static/range", + "invoke-interface/range", + "invoke-polymorphic", + "invoke-polymorphic/range", + "invoke-custom", + "invoke-custom/range", +] +# Any return-* or throw is an end instruction +END_INS = ["return-void", "return", "return-wide", "return-object", "throw"] + + +class DalvikFunctionAnalysisState: + def __init__(self, start_addr, disassembly): + self.start_addr = start_addr + self.disassembly = disassembly + self.block_queue = [start_addr] + self.current_block = [] + self.blocks = [] + self.num_blocks_analyzed = 0 + self.instructions = [] + self.instruction_start_bytes = 
set() + self.processed_blocks = set() + self.processed_bytes = set() + self.jump_targets = set() + self.block_starts = {start_addr} + self.call_register_ins = [] + self.block_start = 0xFFFFFFFF + self.data_bytes = set() + self.data_refs = set() + self.code_refs = set() + self.code_refs_from = {} + self.code_refs_to = {} + self.suspicious_ins_count = 0 + self.is_next_instruction_reachable = True + self.is_block_ending_instruction = False + self.is_sanely_ending = False + self.has_collision = False + self.colliding_addresses = set() + self.is_tailcall_function = False + self.is_leaf_function = True + self.is_recursive = False + self.is_thunk_call = False + self.label = "" + self.metadata = {} + self.decode_error_count = 0 + self.is_partial = False + + def chooseNextBlock(self): + self.is_block_ending_instruction = False + self.block_start = self.block_queue.pop() + self.processed_blocks.update([self.block_start]) + return self.block_start + + def addBlockToQueue(self, block_start): + self.block_starts.add(block_start) + if block_start not in self.processed_blocks: + self.block_queue.append(block_start) + + def addBlockStart(self, block_start): + self.block_starts.add(block_start) + + def endBlock(self): + if self.current_block: + self.num_blocks_analyzed += 1 + self.current_block = [] + + def addInstruction(self, i_address, i_size, i_mnemonic, i_op_str, i_bytes): + ins = (i_address, i_size, i_mnemonic, i_op_str, i_bytes) + self.instructions.append(ins) + self.instruction_start_bytes.add(ins[0]) + self.current_block.append(ins) + for byte in range(i_size): + self.processed_bytes.add(i_address + byte) + if self.is_next_instruction_reachable: + self.addCodeRef(i_address, i_address + i_size) + + def addCodeRef(self, addr_from, addr_to, by_jump=False): + self.code_refs.update([(addr_from, addr_to)]) + refs_from = self.code_refs_from.get(addr_from, set()) + refs_from.update([addr_to]) + self.code_refs_from[addr_from] = refs_from + refs_to = self.code_refs_to.get(addr_to, 
set()) + refs_to.update([addr_from]) + self.code_refs_to[addr_to] = refs_to + if by_jump: + self.jump_targets.update([addr_to]) + + def removeCodeRef(self, addr_from, addr_to): + if (addr_from, addr_to) in self.code_refs: + self.code_refs.remove((addr_from, addr_to)) + if addr_from in self.code_refs_from and addr_to in self.code_refs_from[addr_from]: + self.code_refs_from[addr_from].remove(addr_to) + if addr_to in self.code_refs_to and addr_from in self.code_refs_to[addr_to]: + self.code_refs_to[addr_to].remove(addr_from) + if addr_to in self.jump_targets: + self.jump_targets.remove(addr_to) + + def addDataRef(self, addr_from, addr_to, size=1): + self.data_refs.update([(addr_from, addr_to)]) + for i in range(size): + self.data_bytes.update([addr_to + i]) + + def backtrackInstructions(self, addr_from, num_instructions): + backtracked = [] + for instruction in sorted(self.instructions, key=lambda x: x[0]): + if instruction[0] >= addr_from: + break + backtracked.append(instruction) + return backtracked[-num_instructions:] + + def _finalizeRegularAnalysis(self): + fn_min = min([ins[0] for ins in self.instructions]) + fn_max = max([ins[0] + ins[1] for ins in self.instructions]) + + if self.is_partial: + self.metadata["partial_disassembly"] = True + self.metadata["decode_error_count"] = self.decode_error_count + + self.disassembly.function_symbols[self.start_addr] = self.label + self.disassembly.function_borders[self.start_addr] = (fn_min, fn_max) + self.disassembly.function_metadata[self.start_addr] = self.metadata + for ins in self.instructions: + self.disassembly.instructions[ins[0]] = (ins[2], ins[1]) + for offset in range(ins[1]): + self.disassembly.code_map[ins[0] + offset] = ins[0] + self.disassembly.ins2fn[ins[0] + offset] = self.start_addr + self.disassembly.data_map.update(self.data_bytes) + self.disassembly.functions[self.start_addr] = self.getBlocks() + for cref in self.code_refs: + self.disassembly.addCodeRefs(cref[0], cref[1]) + for dref in self.data_refs: 
+ self.disassembly.addDataRefs(dref[0], dref[1]) + if self.is_recursive: + self.disassembly.recursive_functions.add(self.start_addr) + if self.is_leaf_function: + self.disassembly.leaf_functions.add(self.start_addr) + if self.is_thunk_call: + self.disassembly.thunk_functions.add(self.start_addr) + + def finalizeAnalysis(self, as_gap=False): + if as_gap: + LOGGER.debug( + "0x%08x had sanity state: %s (%d ins, blocks: %d)", + self.start_addr, + self.is_sanely_ending, + len(self.instructions), + self.num_blocks_analyzed, + ) + if as_gap and not self.is_sanely_ending: + if ( + len(self.instructions) == 1 + and self.instructions[0][2].startswith("goto") + or self.num_blocks_analyzed == 1 + and (self.instructions[-1][2].startswith("goto") or self.instructions[-1][2].startswith("invoke-")) + ): + pass + else: + return False + if self.instructions: + self.num_blocks_analyzed = len(self.getBlocks()) + if self.num_blocks_analyzed: + self._finalizeRegularAnalysis() + return True + + def revertAnalysis(self): + self.disassembly.function_borders.pop(self.start_addr, None) + for ins in self.instructions: + self.disassembly.instructions.pop(ins[0], None) + for byte in range(ins[1]): + self.disassembly.code_map.pop(ins[0] + byte, None) + self.disassembly.ins2fn.pop(ins[0] + byte, None) + for cref in self.code_refs: + self.disassembly.removeCodeRefs(cref[0], cref[1]) + for dref in self.data_refs: + self.disassembly.removeDataRefs(dref[0], dref[1]) + self.disassembly.functions.pop(self.start_addr, None) + + def getBlocks(self): + if self.blocks: + return self.blocks + self.instructions.sort() + ins = {i[0]: ind for ind, i in enumerate(self.instructions)} + potential_starts = {self.start_addr} + potential_starts.update(list(self.jump_targets)) + potential_starts.update(self.block_starts) + blocks = [] + for start in sorted(potential_starts): + if start not in ins: + continue + block = [] + for i in range(ins[start], len(self.instructions)): + current = self.instructions[i] + 
block.append(current) + if ( + current[0] in self.code_refs_from + and current[2] not in CALL_INS + and i != len(self.instructions) - 1 + and any(r != self.instructions[i + 1][0] for r in self.code_refs_from[current[0]]) + ): + break + if ( + i != len(self.instructions) - 1 + and self.instructions[i + 1][0] in self.code_refs_to + and ( + len(self.code_refs_to[self.instructions[i + 1][0]]) > 1 + or self.instructions[i + 1][0] in potential_starts + ) + ): + break + if current[2] in END_INS: + break + if block: + blocks.append(block) + self.blocks = blocks + return self.blocks + + def hasUnprocessedBlocks(self): + return len(set(self.block_queue).difference(self.processed_blocks)) > 0 + + def isProcessed(self, addr): + return addr in self.processed_bytes + + def isProcessedFunction(self): + return self.start_addr in self.disassembly.code_map + + def isNextInstructionReachable(self): + return self.is_next_instruction_reachable + + def setNextInstructionReachable(self, is_reachable): + self.is_next_instruction_reachable = is_reachable + + def isBlockEndingInstruction(self): + return self.is_block_ending_instruction + + def isFirstInstruction(self): + return len(self.instructions) == 0 + + def setBlockEndingInstruction(self, is_ending): + self.is_block_ending_instruction = is_ending + + def setSanelyEnding(self, is_sanely_ending): + self.is_sanely_ending = is_sanely_ending + + def addCollision(self, colliding_address): + self.has_collision = True + self.colliding_addresses.add(colliding_address) + + def setRecursion(self, is_recursive): + self.is_recursive = is_recursive + + def setThunkCall(self, is_thunk_call): + self.is_thunk_call = is_thunk_call + + def setLeaf(self, is_leaf): + self.is_leaf_function = is_leaf + + def __str__(self): + result = "0x{:x} | current: 0x{:x} | blocks: {} | queue: {} | processed: {} | crefs: {} | drefs: {} | suspicious: {} | ending: {}".format( + self.start_addr, + self.block_start, + len(self.getBlocks()), + ",".join([f"0x{b:x}" for b in 
sorted(self.block_queue)]), + ",".join([f"0x{b:x}" for b in sorted(self.processed_blocks)]), + len(self.code_refs), + len(self.data_refs), + self.suspicious_ins_count, + self.is_sanely_ending, + ) + return result diff --git a/smda/dalvik/DalvikOpcodeDecoder.py b/smda/dalvik/DalvikOpcodeDecoder.py new file mode 100644 index 0000000..14e56ee --- /dev/null +++ b/smda/dalvik/DalvikOpcodeDecoder.py @@ -0,0 +1,632 @@ +import struct +from dataclasses import dataclass +from typing import Dict, List, Optional + + +def read_uleb128(data, offset): + result = 0 + shift = 0 + current = offset + while current < len(data): + byte = data[current] + result |= (byte & 0x7F) << shift + current += 1 + if byte & 0x80 == 0: + return result, current + shift += 7 + if shift > 35: + break + raise ValueError("Invalid uleb128 encoding") + + +def read_sleb128(data, offset): + result = 0 + shift = 0 + current = offset + size = 32 + while current < len(data): + byte = data[current] + current += 1 + result |= (byte & 0x7F) << shift + shift += 7 + if byte & 0x80 == 0: + if shift < size and byte & 0x40: + result |= -(1 << shift) + return result, current + if shift > 35: + break + raise ValueError("Invalid sleb128 encoding") + + +def parse_code_item_header(raw_data, header_offset): + if header_offset < 0 or header_offset + 16 > len(raw_data): + raise ValueError("Invalid code_item header offset") + ( + registers_size, + ins_size, + outs_size, + tries_size, + debug_info_off, + insns_size, + ) = struct.unpack_from("> 4) & 0x0F + reg_g = raw_bytes[1] & 0x0F + word = int.from_bytes(raw_bytes[4:6], byteorder="little") + reg_c = word & 0x0F + reg_d = (word >> 4) & 0x0F + reg_e = (word >> 8) & 0x0F + reg_f = (word >> 12) & 0x0F + registers = [reg_c, reg_d, reg_e, reg_f, reg_g][:count] + return count, registers + + +def _decode_register_range(count, start): + return list(range(start, start + count)) + + +def _signed(value, bits): + sign_bit = 1 << (bits - 1) + return (value & (sign_bit - 1)) - (value & 
sign_bit) + + +def decode_instruction(bytecode, byte_idx, resolve_ref): + opcode_value = bytecode[byte_idx] + if opcode_value not in OPCODES: + raise ValueError(f"Unknown Dalvik opcode 0x{opcode_value:02x}") + + opcode = OPCODES[opcode_value] + size_bytes = opcode.size_units * 2 + if byte_idx + size_bytes > len(bytecode): + raise ValueError("Truncated Dalvik instruction") + + raw_bytes = bytes(bytecode[byte_idx : byte_idx + size_bytes]) + registers = [] + operands = "" + literal = None + ref_index = None + ref_index_aux = None + branch_target_idx = None + payload_idx = None + + if opcode.fmt == "10x": + operands = "" + elif opcode.fmt == "10t": + branch_delta = int.from_bytes(raw_bytes[1:2], byteorder="little", signed=True) * 2 + branch_target_idx = byte_idx + branch_delta + operands = f"{hex(branch_target_idx)}" + elif opcode.fmt == "11n": + reg_a = raw_bytes[1] & 0x0F + literal = _signed((raw_bytes[1] >> 4) & 0x0F, 4) + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, #{literal:+d}" + elif opcode.fmt == "11x": + reg_a = raw_bytes[1] + registers = [reg_a] + operands = _reg_name(reg_a) + elif opcode.fmt == "12x": + reg_a = raw_bytes[1] & 0x0F + reg_b = (raw_bytes[1] >> 4) & 0x0F + registers = [reg_a, reg_b] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}" + elif opcode.fmt == "20t": + branch_delta = int.from_bytes(raw_bytes[2:4], byteorder="little", signed=True) * 2 + branch_target_idx = byte_idx + branch_delta + operands = f"{hex(branch_target_idx)}" + elif opcode.fmt == "21c": + reg_a = raw_bytes[1] + ref_index = int.from_bytes(raw_bytes[2:4], byteorder="little") + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, {resolve_ref(opcode.ref_kind, ref_index)}" + elif opcode.fmt == "21h": + reg_a = raw_bytes[1] + # The 16-bit immediate is sign-extended per the Dalvik spec before shifting. + # Using signed=True ensures negative values like 0xFFFF produce -1 << shift + # rather than 0xFFFF << shift, matching baksmali's output. 
+ value = int.from_bytes(raw_bytes[2:4], byteorder="little", signed=True) + shift = 48 if opcode.mnemonic.endswith("wide/high16") else 16 + literal = value << shift + registers = [reg_a] + sign = "-" if literal < 0 else "" + operands = f"{_reg_name(reg_a)}, #{sign}{hex(abs(literal))}" + elif opcode.fmt == "21s": + reg_a = raw_bytes[1] + literal = int.from_bytes(raw_bytes[2:4], byteorder="little", signed=True) + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, #{literal:+d}" + elif opcode.fmt == "21t": + reg_a = raw_bytes[1] + branch_delta = int.from_bytes(raw_bytes[2:4], byteorder="little", signed=True) * 2 + branch_target_idx = byte_idx + branch_delta + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, {hex(branch_target_idx)}" + elif opcode.fmt == "22b": + reg_a = raw_bytes[1] + reg_b = raw_bytes[2] + literal = int.from_bytes(raw_bytes[3:4], byteorder="little", signed=True) + registers = [reg_a, reg_b] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}, #{literal:+d}" + elif opcode.fmt == "22c": + reg_a = raw_bytes[1] & 0x0F + reg_b = (raw_bytes[1] >> 4) & 0x0F + ref_index = int.from_bytes(raw_bytes[2:4], byteorder="little") + registers = [reg_a, reg_b] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}, {resolve_ref(opcode.ref_kind, ref_index)}" + elif opcode.fmt == "22s": + reg_a = raw_bytes[1] & 0x0F + reg_b = (raw_bytes[1] >> 4) & 0x0F + literal = int.from_bytes(raw_bytes[2:4], byteorder="little", signed=True) + registers = [reg_a, reg_b] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}, #{literal:+d}" + elif opcode.fmt == "22t": + reg_a = raw_bytes[1] & 0x0F + reg_b = (raw_bytes[1] >> 4) & 0x0F + branch_delta = int.from_bytes(raw_bytes[2:4], byteorder="little", signed=True) * 2 + branch_target_idx = byte_idx + branch_delta + registers = [reg_a, reg_b] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}, {hex(branch_target_idx)}" + elif opcode.fmt == "22x": + reg_a = raw_bytes[1] + reg_b = int.from_bytes(raw_bytes[2:4], 
byteorder="little") + registers = [reg_a, reg_b] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}" + elif opcode.fmt == "23x": + reg_a = raw_bytes[1] + reg_b = raw_bytes[2] + reg_c = raw_bytes[3] + registers = [reg_a, reg_b, reg_c] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}, {_reg_name(reg_c)}" + elif opcode.fmt == "30t": + branch_delta = int.from_bytes(raw_bytes[2:6], byteorder="little", signed=True) * 2 + branch_target_idx = byte_idx + branch_delta + operands = f"{hex(branch_target_idx)}" + elif opcode.fmt == "31c": + reg_a = raw_bytes[1] + ref_index = int.from_bytes(raw_bytes[2:6], byteorder="little") + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, {resolve_ref(opcode.ref_kind, ref_index)}" + elif opcode.fmt == "31i": + reg_a = raw_bytes[1] + literal = int.from_bytes(raw_bytes[2:6], byteorder="little", signed=True) + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, #{literal:+d}" + elif opcode.fmt == "31t": + reg_a = raw_bytes[1] + branch_delta = int.from_bytes(raw_bytes[2:6], byteorder="little", signed=True) * 2 + payload_idx = byte_idx + branch_delta + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, payload@{hex(payload_idx)}" + elif opcode.fmt == "32x": + reg_a = int.from_bytes(raw_bytes[2:4], byteorder="little") + reg_b = int.from_bytes(raw_bytes[4:6], byteorder="little") + registers = [reg_a, reg_b] + operands = f"{_reg_name(reg_a)}, {_reg_name(reg_b)}" + elif opcode.fmt == "35c": + _, registers = _decode_register_list_35c(raw_bytes) + ref_index = int.from_bytes(raw_bytes[2:4], byteorder="little") + operands = f"{{{_format_registers(registers)}}}, {resolve_ref(opcode.ref_kind, ref_index)}" + elif opcode.fmt == "3rc": + count = raw_bytes[1] + ref_index = int.from_bytes(raw_bytes[2:4], byteorder="little") + first_reg = int.from_bytes(raw_bytes[4:6], byteorder="little") + registers = _decode_register_range(count, first_reg) + operands = f"{{{_format_registers(registers)}}}, {resolve_ref(opcode.ref_kind, ref_index)}" + elif 
opcode.fmt == "45cc": + _, registers = _decode_register_list_35c(raw_bytes[:6]) + ref_index = int.from_bytes(raw_bytes[2:4], byteorder="little") + ref_index_aux = int.from_bytes(raw_bytes[6:8], byteorder="little") + operands = "{{{}}}, {}, {}".format( + _format_registers(registers), + resolve_ref(opcode.ref_kind, ref_index), + resolve_ref("proto", ref_index_aux), + ) + elif opcode.fmt == "4rcc": + count = raw_bytes[1] + ref_index = int.from_bytes(raw_bytes[2:4], byteorder="little") + first_reg = int.from_bytes(raw_bytes[4:6], byteorder="little") + ref_index_aux = int.from_bytes(raw_bytes[6:8], byteorder="little") + registers = _decode_register_range(count, first_reg) + operands = "{{{}}}, {}, {}".format( + _format_registers(registers), + resolve_ref(opcode.ref_kind, ref_index), + resolve_ref("proto", ref_index_aux), + ) + elif opcode.fmt == "51l": + reg_a = raw_bytes[1] + literal = int.from_bytes(raw_bytes[2:10], byteorder="little", signed=True) + registers = [reg_a] + operands = f"{_reg_name(reg_a)}, #{hex(literal)}" + else: + raise ValueError(f"Unsupported Dalvik format {opcode.fmt}") + + return DecodedDalvikInstruction( + opcode=opcode_value, + mnemonic=opcode.mnemonic, + fmt=opcode.fmt, + size_units=opcode.size_units, + size_bytes=size_bytes, + bytes_=raw_bytes, + operands=operands, + registers=registers, + literal=literal, + ref_kind=opcode.ref_kind, + ref_index=ref_index, + ref_index_aux=ref_index_aux, + branch_target_idx=branch_target_idx, + payload_idx=payload_idx, + is_invoke=opcode.is_invoke, + is_terminator=opcode.is_terminator, + is_conditional=opcode.is_conditional, + can_throw=opcode.can_throw, + payload_kind=opcode.payload_kind, + ) diff --git a/smda/dalvik/__init__.py b/smda/dalvik/__init__.py new file mode 100644 index 0000000..33b8395 --- /dev/null +++ b/smda/dalvik/__init__.py @@ -0,0 +1,2 @@ +from .DalvikDisassembler import DalvikDisassembler as DalvikDisassembler +from .DalvikFunctionAnalysisState import DalvikFunctionAnalysisState as 
DalvikFunctionAnalysisState
diff --git a/smda/utility/DexFileLoader.py b/smda/utility/DexFileLoader.py
new file mode 100644
index 0000000..44ff169
--- /dev/null
+++ b/smda/utility/DexFileLoader.py
@@ -0,0 +1,83 @@
+import struct
+
+
+class DexFileLoader:
+    SUPPORTED_VERSIONS = {b"035", b"037", b"038", b"039", b"040"}
+    HEADER_SIZE = 0x70
+    ENDIAN_CONSTANT = 0x12345678
+    REVERSE_ENDIAN_CONSTANT = 0x78563412
+
+    @classmethod
+    def _parseHeader(cls, data):
+        if len(data) < cls.HEADER_SIZE:
+            return None
+        magic = data[:8]
+        # Standard DEX (dex\n) or ODEX (dey\n): same structure, same version table
+        if magic.startswith((b"dex\n", b"dey\n")) and magic[7] == 0:
+            version = magic[4:7]
+            if version not in cls.SUPPORTED_VERSIONS:
+                return None
+        # CDEX (ART Compact DEX): cdex\0 — reduced validation; LIEF handles details
+        elif data[:4] == b"cdex":
+            return {"version": "cdex", "file_size": len(data), "data_off": 0, "data_size": len(data)}
+        else:
+            return None
+        file_size = struct.unpack_from("<I", data, 0x20)[0]
+        header_size = struct.unpack_from("<I", data, 0x24)[0]
+        endian_tag = struct.unpack_from("<I", data, 0x28)[0]
+        map_off = struct.unpack_from("<I", data, 0x34)[0]
+        data_size = struct.unpack_from("<I", data, 0x68)[0]
+        data_off = struct.unpack_from("<I", data, 0x6C)[0]
+        if file_size > len(data):
+            return None
+        if endian_tag not in {cls.ENDIAN_CONSTANT, cls.REVERSE_ENDIAN_CONSTANT}:
+            return None
+        if map_off and map_off < header_size:
+            return None
+        if map_off and map_off >= file_size:
+            return None
+        if data_off and data_off > file_size:
+            return None
+        if data_size and data_off + data_size > file_size:
+            return None
+        return {
+            "version": version.decode("ascii"),
+            "file_size": file_size,
+            "data_off": data_off,
+            "data_size": data_size,
+        }
+
+    @classmethod
+    def isCompatible(cls, data):
+        return cls._parseHeader(data) is not None
+
+    @staticmethod
+    def mapBinary(data):
+        return data
+
+    @staticmethod
+    def getBaseAddress(data):
+        return 0
+
+    @staticmethod
+    def getBitness(data):
+        return 32
+
+    @staticmethod
+    def getArchitecture(data):
+        return "dalvik"
+
+    @staticmethod
+    def getAbi(data):
+        return ""
+
+    @classmethod
+    def getCodeAreas(cls, data):
+        header = cls._parseHeader(data)
+        if not header:
+            return []
+        if header["data_off"] and
header["data_size"]: + return [(header["data_off"], header["data_off"] + header["data_size"])] + return [(0, header["file_size"])] diff --git a/smda/utility/FileLoader.py b/smda/utility/FileLoader.py index 47e08b9..695b7a5 100644 --- a/smda/utility/FileLoader.py +++ b/smda/utility/FileLoader.py @@ -1,6 +1,7 @@ import os from smda.utility.DelphiKbFileLoader import DelphiKbFileLoader +from smda.utility.DexFileLoader import DexFileLoader from smda.utility.ElfFileLoader import ElfFileLoader from smda.utility.MachoFileLoader import MachoFileLoader from smda.utility.PeFileLoader import PeFileLoader @@ -16,7 +17,7 @@ class FileLoader: _abi = "" _architecture = "" _code_areas = [] - file_loaders = [PeFileLoader, ElfFileLoader, MachoFileLoader, DelphiKbFileLoader] + file_loaders = [PeFileLoader, ElfFileLoader, MachoFileLoader, DelphiKbFileLoader, DexFileLoader] def __init__(self, file_path, load_file=True, map_file=False): self._file_path = file_path diff --git a/tests/blockblast_classes_xored b/tests/blockblast_classes_xored new file mode 100644 index 0000000..21fb78e Binary files /dev/null and b/tests/blockblast_classes_xored differ diff --git a/tests/testDalvikDisassembler.py b/tests/testDalvikDisassembler.py new file mode 100644 index 0000000..eea86d3 --- /dev/null +++ b/tests/testDalvikDisassembler.py @@ -0,0 +1,599 @@ +#!/usr/bin/python + +import logging +import os +import struct +import subprocess +import sys +import tempfile +import unittest + +from smda.common.BinaryInfo import BinaryInfo +from smda.common.SmdaFunction import SmdaFunction +from smda.common.SmdaReport import SmdaReport +from smda.dalvik.DalvikOpcodeDecoder import decode_instruction, parse_code_item_header, read_sleb128, read_uleb128 +from smda.Disassembler import Disassembler +from smda.DisassemblyResult import DisassemblyResult +from smda.utility.DexFileLoader import DexFileLoader + +from .context import config + +LOG = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, 
format="%(asctime)-15s %(message)s") +logging.disable(logging.CRITICAL) + + +def build_dex_header(version=b"039", file_size=0x70, data_off=0x70, data_size=0): + header = bytearray(0x70) + header[:8] = b"dex\n" + version + b"\x00" + struct.pack_into("method()V" + + def getMethodMetadata(self, method): + return { + "method_name": self.formatMethod(method), + "class_name": "LSynthetic;", + "prototype": "()V", + "access_flags": 0, + "access_flags_decoded": [], + } + + +def build_code_item(insns, tries=None, handlers_blob=b"", registers_size=1, ins_size=0, outs_size=0, debug_info_off=0): + tries = tries or [] + insns_size_units = len(insns) // 2 + header = struct.pack("" in function.function_name for function in functions)) + self.assertTrue(any(function.architecture_metadata for function in functions)) + self.assertTrue(any(function.num_outrefs > 0 for function in functions)) + + any_invoke = False + any_string_ref = False + for function in functions[:250]: + for instruction in function.getInstructions(): + if instruction.mnemonic.startswith("invoke-"): + any_invoke = "->" in instruction.operands or "call_site@" in instruction.operands + if any_invoke: + break + if function.stringrefs: + any_string_ref = True + if any_invoke and any_string_ref: + break + self.assertTrue(any_invoke) + self.assertTrue(any_string_ref) + normalized_invokes = [ + instruction.operands + for function in functions[:250] + for instruction in function.getInstructions() + if instruction.mnemonic.startswith("invoke-") and "->" in instruction.operands + ] + self.assertTrue(normalized_invokes) + self.assertTrue(all(" - " not in operand for operand in normalized_invokes[:50])) + + def testNormalizedBlockRefsPreserveLeafBlocksAndExceptionEdges(self): + func_addr = 0x1000 + disassembly = DisassemblyResult() + disassembly.functions[func_addr] = [ + [ + (0x1000, 2, "invoke-static", "", b"\x6e\x00"), + (0x1002, 2, "move-result-object", "", b"\x0c\x00"), + (0x1004, 2, "return-object", "", b"\x11\x00"), + 
], + [ + (0x1010, 2, "move-exception", "", b"\x0d\x00"), + (0x1012, 2, "return-object", "", b"\x11\x00"), + ], + [ + (0x1020, 2, "move-exception", "", b"\x0d\x00"), + (0x1022, 2, "return-object", "", b"\x11\x00"), + ], + [(0x1030, 2, "return-void", "", b"\x0e\x00")], + ] + disassembly.function_metadata[func_addr] = { + "try_ranges": [ + { + "start_addr": 0x1000, + "end_addr": 0x1004, + "handlers": [{"type_idx": 1, "type_name": "Ljava/lang/Exception;", "target_addr": 0x1010}], + "catch_all_addr": 0x1020, + } + ] + } + blockrefs = disassembly.getBlockRefs(func_addr) + self.assertEqual(blockrefs[0x1000], [0x1010, 0x1020]) + self.assertEqual(blockrefs[0x1010], []) + self.assertEqual(blockrefs[0x1020], []) + self.assertEqual(blockrefs[0x1030], []) + + def testSmdaFunctionNormalizesSerializedDalvikCfg(self): + function_dict = { + "offset": 0x1000, + "blocks": { + 0x1000: [ + [0x1000, "6e00", "invoke-static", ""], + [0x1002, "0c00", "move-result-object", ""], + [0x1004, "1100", "return-object", ""], + ], + 0x1010: [[0x1010, "0d00", "move-exception", ""], [0x1012, "1100", "return-object", ""]], + 0x1020: [[0x1020, "0d00", "move-exception", ""], [0x1022, "1100", "return-object", ""]], + }, + "apirefs": {}, + "stringrefs": {}, + "blockrefs": {}, + "inrefs": [], + "outrefs": {}, + "is_exported": False, + "architecture_metadata": { + "debug_info_off": 0, + "try_ranges": [ + { + "start_addr": 0x1000, + "end_addr": 0x1004, + "handlers": [{"type_idx": 1, "type_name": "Ljava/lang/Exception;", "target_addr": 0x1010}], + "catch_all_addr": 0x1020, + } + ], + "exception_handlers": [ + { + "type_idx": 1, + "type_name": "Ljava/lang/Exception;", + "target_addr": 0x1010, + "protected_range_start": 0x1000, + "protected_range_end": 0x1004, + } + ], + }, + "metadata": { + "binweight": 0, + "characteristics": "", + "confidence": 0.0, + "function_name": "LFoo;->bar()Ljava/lang/Object;", + "pic_hash": None, + "nesting_depth": 0, + "strongly_connected_components": [], + "tfidf": None, + }, + } + 
smda_function = SmdaFunction.fromDict(function_dict) + self.assertIn(0x1000, smda_function.blockrefs) + self.assertEqual(smda_function.blockrefs[0x1000], [0x1010, 0x1020]) + self.assertGreater(smda_function.nesting_depth, 0) + + def testDalvikExceptionMetadataAndNormalizedCfg(self): + functions_with_tries = [ + function + for function in self.file_disassembly.getFunctions() + if function.architecture_metadata.get("try_ranges") + ] + self.assertTrue(functions_with_tries) + function = functions_with_tries[0] + self.assertIn("debug_info_off", function.architecture_metadata) + self.assertIsInstance(function.architecture_metadata["exception_handlers"], list) + self.assertGreaterEqual(function.architecture_metadata["exception_handler_count"], 1) + self.assertIn(function.offset, function.blockrefs) + + def testReportRoundTrip(self): + report_dict = self.file_disassembly.toDict() + self.assertEqual(report_dict["status"], "ok") + self.assertEqual(report_dict["architecture"], "dalvik") + self.assertEqual(report_dict["base_addr"], 0) + self.assertEqual(report_dict["binary_size"], 247668) + self.assertEqual(report_dict["bitness"], 32) + self.assertTrue(report_dict["data_refs_from"] is not None) + self.assertGreater(len(report_dict["xcfg"]), 2000) + + reconstructed = SmdaReport.fromDict(report_dict) + self.assertEqual(reconstructed.status, "ok") + self.assertEqual(reconstructed.architecture, "dalvik") + self.assertEqual(reconstructed.base_addr, 0) + self.assertEqual(reconstructed.binary_size, 247668) + self.assertEqual(reconstructed.sha256, self.file_disassembly.sha256) + self.assertEqual(len(reconstructed.xcfg), len(self.file_disassembly.xcfg)) + + def testBufferDisassembly(self): + self.assertEqual(self.buffer_disassembly.status, "ok") + self.assertEqual(self.buffer_disassembly.architecture, "dalvik") + self.assertEqual(self.buffer_disassembly.bitness, 32) + self.assertEqual(self.buffer_disassembly.base_addr, 0) + self.assertEqual(self.buffer_disassembly.num_functions, 
self.file_disassembly.num_functions) + self.assertEqual(self.buffer_disassembly.num_instructions, self.file_disassembly.num_instructions) + + def testAnalyzeScriptVerboseOutputAvoidsCfgNoise(self): + result = subprocess.run( + [sys.executable, os.path.join(config.PROJECT_ROOT, "analyze.py"), "-p", "-v", self._temp_file_name], + cwd=config.PROJECT_ROOT, + capture_output=True, + text=True, + check=False, + ) + self.assertEqual(result.returncode, 0, msg=result.stderr) + combined_output = result.stdout + result.stderr + self.assertIn("dalvik.32bit", combined_output) + self.assertIn("DEX v", combined_output) + self.assertIn("heuristics=[", combined_output) + self.assertIn("api_refs=", combined_output) + self.assertNotIn("Current analysis callback time", combined_output) + self.assertNotIn("r not in G", combined_output) + + def testCodeItemHeaderParser(self): + header = struct.pack("