@@ -241,7 +241,7 @@ impl AnalyzerPass for GraphJoinInference {
241241 // Empty joins vector = fully denormalized pattern (no JOINs needed)
242242 // Without this wrapper, RenderPlan will try to generate JOINs from raw GraphRel
243243 let optional_aliases = plan_ctx. get_optional_aliases ( ) . clone ( ) ;
244- Self :: build_graph_joins ( logical_plan, & mut collected_graph_joins, optional_aliases, plan_ctx)
244+ Self :: build_graph_joins ( logical_plan, & mut collected_graph_joins, optional_aliases, plan_ctx, graph_schema )
245245 }
246246}
247247
@@ -775,19 +775,40 @@ impl GraphJoinInference {
775775 collected_graph_joins : & mut Vec < Join > ,
776776 optional_aliases : std:: collections:: HashSet < String > ,
777777 plan_ctx : & PlanCtx ,
778+ graph_schema : & GraphSchema ,
778779 ) -> AnalyzerResult < Transformed < Arc < LogicalPlan > > > {
779780 let transformed_plan = match logical_plan. as_ref ( ) {
780- // If input is a Union, push GraphJoins into each branch
781+ // If input is a Union, process each branch INDEPENDENTLY
782+ // Each branch needs its own collect_graph_joins + build_graph_joins pass
781783 LogicalPlan :: Union ( union) => {
782- log:: info!( "🔄 Union detected in build_graph_joins, processing {} branches" , union . inputs. len( ) ) ;
784+ log:: info!( "🔄 Union detected in build_graph_joins, processing {} branches independently " , union . inputs. len( ) ) ;
783785 let mut any_transformed = false ;
786+ let graph_join_inference = GraphJoinInference :: new ( ) ;
787+
784788 let transformed_branches: Result < Vec < Arc < LogicalPlan > > , _ > = union. inputs . iter ( ) . map ( |branch| {
785- let mut branch_joins = collected_graph_joins. clone ( ) ;
789+ // CRITICAL: Each branch needs fresh state - collect and build separately
790+ let mut branch_joins: Vec < Join > = vec ! [ ] ;
791+ let mut branch_joined_entities: HashSet < String > = HashSet :: new ( ) ;
792+
793+ // Collect joins for this specific branch only
794+ graph_join_inference. collect_graph_joins (
795+ branch. clone ( ) ,
796+ branch. clone ( ) ,
797+ & mut plan_ctx. clone ( ) , // Clone PlanCtx for each branch
798+ graph_schema,
799+ & mut branch_joins,
800+ & mut branch_joined_entities,
801+ ) ?;
802+
803+ eprintln ! ( "🔹 Union branch collected {} joins" , branch_joins. len( ) ) ;
804+
805+ // Build GraphJoins for this branch with its own collected joins
786806 let result = Self :: build_graph_joins (
787807 branch. clone ( ) ,
788808 & mut branch_joins,
789809 optional_aliases. clone ( ) ,
790810 plan_ctx,
811+ graph_schema,
791812 ) ?;
792813 if matches ! ( result, Transformed :: Yes ( _) ) {
793814 any_transformed = true ;
@@ -834,6 +855,7 @@ impl GraphJoinInference {
834855 collected_graph_joins,
835856 optional_aliases. clone ( ) ,
836857 plan_ctx,
858+ graph_schema,
837859 ) ?;
838860
839861 // is_denormalized flag is set by view_optimizer pass - just rebuild
@@ -845,18 +867,21 @@ impl GraphJoinInference {
845867 collected_graph_joins,
846868 optional_aliases. clone ( ) ,
847869 plan_ctx,
870+ graph_schema,
848871 ) ?;
849872 let center_tf = Self :: build_graph_joins (
850873 graph_rel. center . clone ( ) ,
851874 collected_graph_joins,
852875 optional_aliases. clone ( ) ,
853876 plan_ctx,
877+ graph_schema,
854878 ) ?;
855879 let right_tf = Self :: build_graph_joins (
856880 graph_rel. right . clone ( ) ,
857881 collected_graph_joins,
858882 optional_aliases. clone ( ) ,
859883 plan_ctx,
884+ graph_schema,
860885 ) ?;
861886
862887 graph_rel. rebuild_or_clone ( left_tf, center_tf, right_tf, logical_plan. clone ( ) )
@@ -867,6 +892,7 @@ impl GraphJoinInference {
867892 collected_graph_joins,
868893 optional_aliases,
869894 plan_ctx,
895+ graph_schema,
870896 ) ?;
871897 cte. rebuild_or_clone ( child_tf, logical_plan. clone ( ) )
872898 }
@@ -878,6 +904,7 @@ impl GraphJoinInference {
878904 collected_graph_joins,
879905 optional_aliases,
880906 plan_ctx,
907+ graph_schema,
881908 ) ?;
882909 graph_joins. rebuild_or_clone ( child_tf, logical_plan. clone ( ) )
883910 }
@@ -887,6 +914,7 @@ impl GraphJoinInference {
887914 collected_graph_joins,
888915 optional_aliases,
889916 plan_ctx,
917+ graph_schema,
890918 ) ?;
891919 filter. rebuild_or_clone ( child_tf, logical_plan. clone ( ) )
892920 }
@@ -896,6 +924,7 @@ impl GraphJoinInference {
896924 collected_graph_joins,
897925 optional_aliases,
898926 plan_ctx,
927+ graph_schema,
899928 ) ?;
900929 group_by. rebuild_or_clone ( child_tf, logical_plan. clone ( ) )
901930 }
@@ -905,6 +934,7 @@ impl GraphJoinInference {
905934 collected_graph_joins,
906935 optional_aliases,
907936 plan_ctx,
937+ graph_schema,
908938 ) ?;
909939 order_by. rebuild_or_clone ( child_tf, logical_plan. clone ( ) )
910940 }
@@ -914,6 +944,7 @@ impl GraphJoinInference {
914944 collected_graph_joins,
915945 optional_aliases,
916946 plan_ctx,
947+ graph_schema,
917948 ) ?;
918949 skip. rebuild_or_clone ( child_tf, logical_plan. clone ( ) )
919950 }
@@ -923,6 +954,7 @@ impl GraphJoinInference {
923954 collected_graph_joins,
924955 optional_aliases,
925956 plan_ctx,
957+ graph_schema,
926958 ) ?;
927959 limit. rebuild_or_clone ( child_tf, logical_plan. clone ( ) )
928960 }
@@ -934,6 +966,7 @@ impl GraphJoinInference {
934966 collected_graph_joins,
935967 optional_aliases. clone ( ) ,
936968 plan_ctx,
969+ graph_schema,
937970 ) ?;
938971 inputs_tf. push ( child_tf) ;
939972 }
@@ -947,6 +980,7 @@ impl GraphJoinInference {
947980 collected_graph_joins,
948981 optional_aliases,
949982 plan_ctx,
983+ graph_schema,
950984 ) ?;
951985 match child_tf {
952986 Transformed :: Yes ( new_input) => Transformed :: Yes ( Arc :: new ( LogicalPlan :: Unwind ( crate :: query_planner:: logical_plan:: Unwind {
@@ -1153,17 +1187,11 @@ impl GraphJoinInference {
11531187 )
11541188 }
11551189 LogicalPlan :: Union ( union) => {
1156- eprintln ! ( "� ? Union, recursing into {} inputs" , union . inputs. len( ) ) ;
1157- for input_plan in union. inputs . iter ( ) {
1158- self . collect_graph_joins (
1159- input_plan. clone ( ) ,
1160- root_plan. clone ( ) ,
1161- plan_ctx,
1162- graph_schema,
1163- collected_graph_joins,
1164- joined_entities,
1165- ) ?;
1166- }
1190+ // CRITICAL: Don't recurse into UNION branches here!
1191+ // Each branch will be processed independently by build_graph_joins,
1192+ // which properly clones the state for each branch.
1193+ // If we recurse here with shared state, branches pollute each other.
1194+ eprintln ! ( "🔀 Union detected in collect_graph_joins - skipping recursion (handled by build_graph_joins)" ) ;
11671195 Ok ( ( ) )
11681196 }
11691197 LogicalPlan :: PageRank ( _) => {
@@ -1529,12 +1557,18 @@ impl GraphJoinInference {
15291557 to_id : "to_node_id" . to_string ( ) ,
15301558 } ,
15311559 ) ;
1532- let rel_from_col = rel_cols. from_id ;
1533- let rel_to_col = rel_cols. to_id ;
1560+
1561+ // For Direction::Incoming (from BidirectionalUnion), swap the columns
1562+ // so that the "from" side of the relationship connects to the "to" node
1563+ let ( rel_from_col, rel_to_col) = if graph_rel. direction == Direction :: Incoming {
1564+ ( rel_cols. to_id , rel_cols. from_id ) // Swapped for incoming direction
1565+ } else {
1566+ ( rel_cols. from_id , rel_cols. to_id ) // Normal for outgoing/either
1567+ } ;
15341568
15351569 eprintln ! (
1536- " � ?? DEBUG REL COLUMNS: rel_from_col = '{}', rel_to_col = '{}'" ,
1537- rel_from_col, rel_to_col
1570+ " 🔹 DEBUG REL COLUMNS: direction={:?}, rel_from_col = '{}', rel_to_col = '{}'" ,
1571+ graph_rel . direction , rel_from_col, rel_to_col
15381572 ) ;
15391573
15401574 // If both nodes are of the same type then check the direction to determine where are the left and right nodes present in the edgelist.
@@ -1923,6 +1957,7 @@ impl GraphJoinInference {
19231957 // DON'T mark as joined - denormalized nodes are virtual, not physical tables
19241958 } else {
19251959 // Traditional: Join LEFT node first
1960+ eprintln ! ( " 🔹 CREATING LEFT JOIN: u1 ON r.{}" , rel_from_col) ;
19261961 let left_graph_join = Join {
19271962 table_name : left_cte_name. clone ( ) ,
19281963 table_alias : left_alias. to_string ( ) ,
@@ -1944,7 +1979,7 @@ impl GraphJoinInference {
19441979 } ;
19451980 collected_graph_joins. push ( left_graph_join) ;
19461981 joined_entities. insert ( left_alias. to_string ( ) ) ;
1947- eprintln ! ( " � ? LEFT node '{}' joined first" , left_alias) ;
1982+ eprintln ! ( " ✓ LEFT node '{}' joined first" , left_alias) ;
19481983 }
19491984 }
19501985
@@ -3623,6 +3658,9 @@ mod tests {
36233658 assert_eq ! ( rel_join_condition. operands. len( ) , 2 ) ;
36243659
36253660 // For incoming direction, the relationship connects differently
3661+ // Pattern (p2)<-[f1]-(p1) means p1 FOLLOWS p2, so:
3662+ // - f1.from_id = p1.id (source)
3663+ // - f1.to_id = p2.id (target) ← p2 is the anchor, connects via to_id
36263664 match (
36273665 & rel_join_condition. operands [ 0 ] ,
36283666 & rel_join_condition. operands [ 1 ] ,
@@ -3632,7 +3670,7 @@ mod tests {
36323670 LogicalExpr :: PropertyAccessExp ( right_prop) ,
36333671 ) => {
36343672 assert_eq ! ( rel_prop. table_alias. 0 , "f1" ) ;
3635- assert_eq ! ( rel_prop. column. raw( ) , "from_id " ) ;
3673+ assert_eq ! ( rel_prop. column. raw( ) , "to_id " ) ; // p2 is target, connects via to_id
36363674 assert_eq ! ( right_prop. table_alias. 0 , "p2" ) ;
36373675 assert_eq ! ( right_prop. column. raw( ) , "id" ) ;
36383676 }
@@ -3659,7 +3697,7 @@ mod tests {
36593697 assert_eq ! ( p1_prop. table_alias. 0 , "p1" ) ;
36603698 assert_eq ! ( p1_prop. column. raw( ) , "id" ) ;
36613699 assert_eq ! ( rel_prop. table_alias. 0 , "f1" ) ;
3662- assert_eq ! ( rel_prop. column. raw( ) , "to_id " ) ;
3700+ assert_eq ! ( rel_prop. column. raw( ) , "from_id " ) ; // p1 is source, connects via from_id
36633701 }
36643702 _ => panic ! ( "Expected PropertyAccessExp operands for p1 join" ) ,
36653703 }
0 commit comments