diff --git a/src/org/jetbrains/java/decompiler/modules/decompiler/StackVarsProcessor.java b/src/org/jetbrains/java/decompiler/modules/decompiler/StackVarsProcessor.java index 159882deea..d30a1b1ba3 100644 --- a/src/org/jetbrains/java/decompiler/modules/decompiler/StackVarsProcessor.java +++ b/src/org/jetbrains/java/decompiler/modules/decompiler/StackVarsProcessor.java @@ -6,18 +6,11 @@ import org.jetbrains.java.decompiler.main.DecompilerContext; import org.jetbrains.java.decompiler.main.extern.IFernflowerPreferences; import org.jetbrains.java.decompiler.modules.decompiler.exps.*; -import org.jetbrains.java.decompiler.modules.decompiler.flow.DirectEdge; -import org.jetbrains.java.decompiler.modules.decompiler.flow.DirectEdgeType; -import org.jetbrains.java.decompiler.modules.decompiler.flow.DirectGraph; -import org.jetbrains.java.decompiler.modules.decompiler.flow.DirectNode; -import org.jetbrains.java.decompiler.modules.decompiler.flow.DirectNodeType; -import org.jetbrains.java.decompiler.modules.decompiler.flow.FlattenStatementsHelper; import org.jetbrains.java.decompiler.modules.decompiler.exps.FunctionExprent.FunctionType; -import org.jetbrains.java.decompiler.modules.decompiler.sforms.*; -import org.jetbrains.java.decompiler.modules.decompiler.stats.CatchStatement; -import org.jetbrains.java.decompiler.modules.decompiler.stats.DoStatement; -import org.jetbrains.java.decompiler.modules.decompiler.stats.RootStatement; -import org.jetbrains.java.decompiler.modules.decompiler.stats.Statement; +import org.jetbrains.java.decompiler.modules.decompiler.flow.*; +import org.jetbrains.java.decompiler.modules.decompiler.sforms.SSAConstructorSparseEx; +import org.jetbrains.java.decompiler.modules.decompiler.sforms.SSAUConstructorSparseEx; +import org.jetbrains.java.decompiler.modules.decompiler.stats.*; import org.jetbrains.java.decompiler.modules.decompiler.vars.VarVersionNode; import org.jetbrains.java.decompiler.modules.decompiler.vars.VarVersionPair; import org.jetbrains.java.decompiler.modules.decompiler.vars.VarVersionsGraph; @@ -27,9 +20,8 @@ import org.jetbrains.java.decompiler.struct.gen.CodeType; import org.jetbrains.java.decompiler.struct.gen.MethodDescriptor; import org.jetbrains.java.decompiler.struct.gen.VarType; -import org.jetbrains.java.decompiler.util.DotExporter; -import org.jetbrains.java.decompiler.util.collections.FastSparseSetFactory.FastSparseSet; import org.jetbrains.java.decompiler.util.InterpreterUtil; +import org.jetbrains.java.decompiler.util.collections.FastSparseSetFactory.FastSparseSet; import org.jetbrains.java.decompiler.util.collections.SFormsFastMapDirect; import java.util.*; @@ -119,103 +111,129 @@ private static void setExprentVersionsToNull(Exprent exprent) { } } + record NodePath(DirectNode node, Map map, List nodes) {} + + record Node(List exprents, Statement stat) {} + private static boolean iterateStatements(RootStatement root, SSAUConstructorSparseEx ssa, StackSimplifyOptions options) { FlattenStatementsHelper flatthelper = new FlattenStatementsHelper(); DirectGraph dgraph = flatthelper.buildDirectGraph(root); - boolean res = false; + Map lvts = new HashMap<>(); - Set setVisited = new HashSet<>(); - LinkedList stack = new LinkedList<>(); - LinkedList> stackMaps = new LinkedList<>(); + boolean res = iterateStatementsNested(root, dgraph.first, ssa, options, dgraph, lvts, new HashSet<>(Set.of(dgraph.first)), new LinkedList<>()); - stack.add(dgraph.first); - stackMaps.add(new HashMap<>()); + dgraph.iterateExprentsDeep(ex -> { + if (ex instanceof VarExprent var && lvts.containsKey(var.getVarVersionPair())) { + LocalVariable lvt = lvts.get(var.getVarVersionPair()); + if (var.getLVT() == null) { + var.setLVT(lvt); + } + } - Map lvts = new HashMap<>(); + return 0; + }); - int[] ret = {0, 0}; - while (!stack.isEmpty()) { - DirectNode nd = stack.removeFirst(); - Map mapVarValues = stackMaps.removeFirst(); + return res; + } - if (setVisited.contains(nd)) { - continue; - } + private static boolean iterateStatementsNested(Statement root, DirectNode start, SSAUConstructorSparseEx ssa, StackSimplifyOptions options, DirectGraph dgraph, Map lvts, Set setVisited, List exits) { + boolean res = false; - setVisited.add(nd); + LinkedList stack = new LinkedList<>(); - List> lstLists = new ArrayList<>(); + stack.add(new NodePath(start, new HashMap<>(), new ArrayList<>())); + int[] ret = {0, 0}; + while (!stack.isEmpty()) { + NodePath path = stack.removeFirst(); + DirectNode nd = path.node; + Map mapVarValues = path.map; + + List lstLists = path.nodes; if (!nd.exprents.isEmpty()) { - lstLists.add(nd.exprents); + lstLists.add(new Node(nd.exprents, nd.statement)); } - List succs = nd.getSuccessors(DirectEdgeType.REGULAR); - if (succs.size() == 1) { - DirectNode ndsucc = succs.get(0).getDestination(); - - if (ndsucc.type == DirectNodeType.TAIL && !ndsucc.exprents.isEmpty()) { - lstLists.add(succs.get(0).getDestination().exprents); - nd = ndsucc; + List succs = nd.getSuccessors(DirectEdgeType.REGULAR).stream().map(DirectEdge::getDestination).toList(); + List newSuccs = new ArrayList<>(); + + for (DirectNode succ : succs) { + if (!setVisited.add(succ)) { + // Skip seen nodes + } else if (succ.statement instanceof SwitchStatement switchSt && switchSt.isPhantom() && root != switchSt) { + // Treat phantom switch statements as a sub method + List subExits = new ArrayList<>(); + res |= iterateStatementsNested(switchSt, succ, ssa, options, dgraph, lvts, setVisited, subExits); + newSuccs.addAll(subExits); + } else if (!root.containsStatement(succ.statement) && !(root instanceof RootStatement)) { + exits.add(succ); + } else { + newSuccs.add(succ); } } - // To handle stacks created by some duplicated bytecode (dup, dup_x2, etc.) better, we run the simplification algorithm in 2 passes. - // The first pass is the classic algorithm, and the second pass is a more aggressive one that allows some slightly unsafe operations, such as simplifying across variables. - // To ensure the second pass doesn't break good bytecode, if the first pass manages to update any exprent it will cancel the second pass. - // This behavior can be turned off with a fernflower preference. - for (int stackStage = 0; stackStage < 2; stackStage++) { - // If instructed to not use the second pass, set it to 2 here to prevent the loop from working - if (!DecompilerContext.getOption(IFernflowerPreferences.SIMPLIFY_STACK_SECOND_PASS)) { - stackStage = 2; - } + succs = newSuccs; - for (int i = 0; i < lstLists.size(); i++) { - List lst = lstLists.get(i); + if (succs.size() == 1) { + stack.add(new NodePath(succs.get(0), new HashMap<>(mapVarValues), new ArrayList<>(lstLists))); + } else { + // To handle stacks created by some duplicated bytecode (dup, dup_x2, etc.) better, we run the simplification algorithm in 2 passes. + // The first pass is the classic algorithm, and the second pass is a more aggressive one that allows some slightly unsafe operations, such as simplifying across variables. + // To ensure the second pass doesn't break good bytecode, if the first pass manages to update any exprent it will cancel the second pass. + // This behavior can be turned off with a fernflower preference. + for (int stackStage = 0; stackStage < 2; stackStage++) { + // If instructed to not use the second pass, set it to 2 here to prevent the loop from working + if (!DecompilerContext.getOption(IFernflowerPreferences.SIMPLIFY_STACK_SECOND_PASS)) { + stackStage = 2; + } - int index = 0; - while (index < lst.size()) { - Exprent next = null; + for (int i = 0; i < lstLists.size(); i++) { + Node node = lstLists.get(i); + List lst = node.exprents; - if (index == lst.size() - 1) { - if (i < lstLists.size() - 1) { - next = lstLists.get(i + 1).get(0); + int index = 0; + while (index < lst.size()) { + Exprent next = null; + + if (index == lst.size() - 1) { + if (i < lstLists.size() - 1) { + next = lstLists.get(i + 1).exprents.get(0); + } + } else { + next = lst.get(index + 1); } - } else { - next = lst.get(index + 1); - } - boolean simplifyAcrossStack = stackStage == 1; + boolean simplifyAcrossStack = stackStage == 1; - // {newIndex, changed} - iterateExprent(lst, nd.statement, lvts, index, next, mapVarValues, ssa, simplifyAcrossStack, ret, options); + // {newIndex, changed} + iterateExprent(lst, node.stat, lvts, index, next, mapVarValues, ssa, simplifyAcrossStack, ret, options); - // If index is specified, set to that - if (ret[0] >= 0) { - index = ret[0]; - } else { - // Otherwise, continue to next index - index++; - } + // If index is specified, set to that + if (ret[0] >= 0) { + index = ret[0]; + } else { + // Otherwise, continue to next index + index++; + } - // Mark if we changed - boolean changed = ret[1] == 1; - res |= changed; + // Mark if we changed + boolean changed = ret[1] == 1; + res |= changed; - // We only want to simplify across stack bounds if we were not able to change *anything* - if (changed) { - // Cancel the second pass by setting the stage to 2, preventing the check next time it runs - stackStage = 2; + // We only want to simplify across stack bounds if we were not able to change *anything* + if (changed) { + // Cancel the second pass by setting the stage to 2, preventing the check next time it runs + stackStage = 2; + } + // An (unintentional) side effect of this implementation is that as soon as the second pass is able to change the stack, it'll cancel further iteration of the second pass, preventing it from creating wrong code by accident. } - // An (unintentional) side effect of this implementation is that as soon as the second pass is able to change the stack, it'll cancel further iteration of the second pass, preventing it from creating wrong code by accident. } } - } - for (DirectEdge ndx : succs) { - stack.add(ndx.getDestination()); - stackMaps.add(new HashMap<>(mapVarValues)); + for (DirectNode ndx : succs) { + stack.add(new NodePath(ndx, new HashMap<>(mapVarValues), new ArrayList<>())); + } } // make sure the 3 special exprent lists in a loop (init, condition, increment) are not empty @@ -235,18 +253,6 @@ private static boolean iterateStatements(RootStatement root, SSAUConstructorSpar } } } - - dgraph.iterateExprentsDeep(ex -> { - if (ex instanceof VarExprent var && lvts.containsKey(var.getVarVersionPair())) { - LocalVariable lvt = lvts.get(var.getVarVersionPair()); - if (var.getLVT() == null) { - var.setLVT(lvt); - } - } - - return 0; - }); - return res; } @@ -316,7 +322,11 @@ private static void iterateExprent(List lstExprents, int changed = 0; Object[] arr = {null, false, false}; - for (Exprent expr : exprent.getAllExprents()) { + List containing = exprent.getAllExprents(); + if (exprent instanceof SwitchExprent switchExp) { + containing = switchExp.getBacking().getHeadexprentList(); + } + for (Exprent expr : containing) { while (true) { iterateChildExprent(expr, exprent, next, mapVarValues, ssau, arr, options); Exprent retexpr = (Exprent)arr[0]; @@ -733,7 +743,7 @@ private static boolean areTreesEqual(VarExprent left, List treeA, List< private static Set getAllVersions(Exprent exprent) { Set res = new HashSet<>(); - List exprents = exprent.getAllExprents(true); + List exprents = getAllExprents(exprent); exprents.add(exprent); for (Exprent expr : exprents) { @@ -758,7 +768,11 @@ private static void iterateChildExprent(Exprent exprent, boolean changed = false; Object[] arr = {null, false, false}; - for (Exprent expr : exprent.getAllExprents()) { + List containing = exprent.getAllExprents(); + if (exprent instanceof SwitchExprent switchExp) { + containing = switchExp.getBacking().getHeadexprentList(); + } + for (Exprent expr : containing) { while (true) { iterateChildExprent(expr, parent, next, mapVarValues, ssau, arr, options); Exprent retexpr = (Exprent)arr[0]; @@ -810,7 +824,7 @@ private static void iterateChildExprent(Exprent exprent, } boolean isHeadSynchronized = false; - if (next == null && parent instanceof MonitorExprent) { + if (parent instanceof MonitorExprent) { MonitorExprent monexpr = (MonitorExprent)parent; if (monexpr.getMonType() == MonitorExprent.Type.ENTER && exprent.equals(monexpr.getValue())) { isHeadSynchronized = true; @@ -1003,7 +1017,7 @@ private static Map> getAllVarVersions(VarVersionPai Map> map = new HashMap<>(); SFormsFastMapDirect mapLiveVars = ssau.getLiveVarVersionsMap(leftvar); - List lst = exprent.getAllExprents(true); + List lst = getAllExprents(exprent); lst.add(exprent); for (Exprent expr : lst) { @@ -1209,6 +1223,16 @@ private static List getRoots(VarVersionNode vvnode) { return ret; } + private static List getAllExprents(Exprent exprent) { + List exprents = exprent.getAllExprents(true); + for (int i = 0; i < exprents.size(); i++) { + if (exprents.get(i) instanceof SwitchExprent switchExp) { + exprents.addAll(getAllExprents(switchExp.getBacking().getHeadexprent())); + } + } + return exprents; + } + public static class StackSimplifyOptions { private boolean inlineRegularVars = false; public StackSimplifyOptions() { diff --git a/test/org/jetbrains/java/decompiler/SingleClassesTest.java b/test/org/jetbrains/java/decompiler/SingleClassesTest.java index 2d27082823..0532495faf 100644 --- a/test/org/jetbrains/java/decompiler/SingleClassesTest.java +++ b/test/org/jetbrains/java/decompiler/SingleClassesTest.java @@ -831,6 +831,8 @@ private void registerDefault() { register(JAVA_17, "TestStringSwitchTypes"); register(JAVA_16, "TestSwitchExpressionIfBlocks"); + + register(JAVA_16, "TestSwitchExpressionMultiple"); } private void registerEntireClassPath() { diff --git a/testData/results/pkg/TestSwitchExpressionMultiple.dec b/testData/results/pkg/TestSwitchExpressionMultiple.dec new file mode 100644 index 0000000000..98d46e341d --- /dev/null +++ b/testData/results/pkg/TestSwitchExpressionMultiple.dec @@ -0,0 +1,171 @@ +package pkg; + +public class TestSwitchExpressionMultiple { + public int test(int i1, int i2) { + return switch (switch (i2) {// 5 + case 0 -> 1;// 6 + case 1 -> 0;// 7 + default -> 0;// 8 + }) { + case 0 -> 1;// 10 + case 1 -> 0;// 11 + default -> 0;// 12 + } + switch (i1) { + case 0 -> { + switch (i2) {// 14 + case 0: + yield 1;// 15 + case 1: + yield 0;// 16 + default: + yield 0;// 17 + } + } + case 1 -> 0;// 19 + default -> 0;// 20 + }; + } +} + +class 'pkg/TestSwitchExpressionMultiple' { + method 'test (II)I' { + 0 4 + 1 4 + 2 4 + 3 4 + 4 4 + 5 4 + 6 4 + 7 4 + 8 4 + 9 4 + a 4 + b 4 + c 4 + d 4 + e 4 + f 4 + 10 4 + 11 4 + 12 4 + 13 4 + 14 4 + 15 4 + 16 4 + 17 4 + 18 4 + 19 4 + 1a 4 + 1b 4 + 1c 5 + 20 6 + 24 7 + 25 4 + 26 4 + 27 4 + 28 4 + 29 4 + 2a 4 + 2b 4 + 2c 4 + 2d 4 + 2e 4 + 2f 4 + 30 4 + 31 4 + 32 4 + 33 4 + 34 4 + 35 4 + 36 4 + 37 4 + 38 4 + 39 4 + 3a 4 + 3b 4 + 3c 4 + 3d 4 + 3e 4 + 3f 4 + 40 9 + 44 10 + 48 11 + 49 12 + 4a 12 + 4b 12 + 4c 12 + 4d 12 + 4e 12 + 4f 12 + 50 12 + 51 12 + 52 12 + 53 12 + 54 12 + 55 12 + 56 12 + 57 12 + 58 12 + 59 12 + 5a 12 + 5b 12 + 5c 12 + 5d 12 + 5e 12 + 5f 12 + 60 12 + 61 12 + 62 12 + 63 12 + 64 14 + 65 14 + 66 14 + 67 14 + 68 14 + 69 14 + 6a 14 + 6b 14 + 6c 14 + 6d 14 + 6e 14 + 6f 14 + 70 14 + 71 14 + 72 14 + 73 14 + 74 14 + 75 14 + 76 14 + 77 14 + 78 14 + 79 14 + 7a 14 + 7b 14 + 7c 14 + 7d 14 + 7e 14 + 7f 14 + 80 16 + 84 18 + 88 20 + 8c 23 + 90 24 + 91 4 + 92 4 + } +} + +Lines mapping: +5 <-> 5 +6 <-> 6 +7 <-> 7 +8 <-> 8 +10 <-> 10 +11 <-> 11 +12 <-> 12 +14 <-> 15 +15 <-> 17 +16 <-> 19 +17 <-> 21 +19 <-> 24 +20 <-> 25 diff --git a/testData/results/pkg/TestSynchronized.dec b/testData/results/pkg/TestSynchronized.dec index 03593782b8..2a35086732 100644 --- a/testData/results/pkg/TestSynchronized.dec +++ b/testData/results/pkg/TestSynchronized.dec @@ -609,4 +609,4 @@ Not mapped: 83 88 93 -134 \ No newline at end of file +134 diff --git a/testData/src/java16/pkg/TestSwitchExpressionMultiple.java b/testData/src/java16/pkg/TestSwitchExpressionMultiple.java new file mode 100644 index 0000000000..d9b0e58730 --- /dev/null +++ b/testData/src/java16/pkg/TestSwitchExpressionMultiple.java @@ -0,0 +1,23 @@ +package pkg; + +public class TestSwitchExpressionMultiple { + public int test(int i1, int i2) { + return switch (switch (i2) { + case 0 -> 1; + case 1 -> 0; + default -> 0; + }) { + case 0 -> 1; + case 1 -> 0; + default -> 0; + } + switch (i1) { + case 0 -> switch (i2) { + case 0 -> 1; + case 1 -> 0; + default -> 0; + }; + case 1 -> 0; + default -> 0; + }; + } +}