Skip to content

Commit 7f86af4

Browse files
viktorcosclaude
andcommitted
Port COS NestedSelectClauseRewriter fix to NHibernate 5.5.3
Replace the flat SelectClauseRewriter approach in NestedSelectRewriter.ReWrite() with a new recursive NestedSelectClauseRewriter class, ported from the COS Systems NHibernateAll 4.0.4.4002 custom build. This fixes support for nested collection expands (e.g. OData $expand=Orders/OrderLines) where the LINQ provider previously generated stray dots in SQL and failed to populate second-level collections. Also update global.json rollForward to latestMajor to support building with .NET SDK 9+. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d877aad commit 7f86af4

5 files changed

Lines changed: 331 additions & 48 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ TestResult.xml
1818
.idea/
1919
.vs/
2020
/build-common/NHibernate.dev.props
21+
/nupkg/
2122
/doc/reference/master.xml
2223
/doc/bin/
2324
/doc/obj/
2425
/Tools/bin/
25-
/Tools/obj/
26+
/Tools/obj/

build-common/NHibernate.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<NhVersion Condition="'$(NhVersion)' == ''" >5.6</NhVersion>
66
<VersionPatch Condition="'$(VersionPatch)' == ''">0</VersionPatch>
77
<!-- Clear VersionSuffix for making release and set it to dev for making development builds -->
8-
<VersionSuffix Condition="'$(VersionSuffix)' == ''"></VersionSuffix>
8+
<VersionSuffix Condition="'$(VersionSuffix)' == ''">cos.1</VersionSuffix>
99
<LangVersion Condition="'$(MSBuildProjectExtension)' != '.vbproj'">12.0</LangVersion>
1010

1111
<VersionPrefix Condition="'$(VersionPrefix)' == ''">$(NhVersion).$(VersionPatch)</VersionPrefix>

global.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"sdk": {
33
"version": "8.0.100",
4-
"rollForward": "latestFeature"
4+
"rollForward": "latestMajor"
55
}
66
}
Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
using System;
2+
using System.Collections;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Linq.Expressions;
6+
using System.Reflection;
7+
using NHibernate.Linq.Clauses;
8+
using NHibernate.Linq.GroupBy;
9+
using NHibernate.Linq.Visitors;
10+
using NHibernate.Util;
11+
using Remotion.Linq;
12+
using Remotion.Linq.Clauses;
13+
using Remotion.Linq.Clauses.Expressions;
14+
using Remotion.Linq.Parsing;
15+
16+
namespace NHibernate.Linq.NestedSelects
17+
{
18+
/// <summary>
19+
/// Recursive visitor that rewrites nested collection expands (e.g. OData $expand=Orders/OrderLines)
20+
/// into a chain of GroupBy/Select expressions that NHibernate can translate to SQL.
21+
/// Ported from the COS Systems NHibernateAll 4.0.4.4002 custom build.
22+
/// </summary>
23+
internal class NestedSelectClauseRewriter : RelinqExpressionVisitor
24+
{
25+
private static readonly MethodInfo CastMethod =
26+
ReflectHelper.FastGetMethod(Enumerable.Cast<object[]>, default(IEnumerable));
27+
28+
private static readonly MethodInfo GroupByMethod =
29+
ReflectHelper.FastGetMethod(Enumerable.GroupBy<object[], object>,
30+
default(IEnumerable<object[]>), default(Func<object[], object>));
31+
32+
private static readonly MethodInfo SingleMethod =
33+
ReflectHelper.FastGetMethod(Enumerable.Single<IGrouping<object, object[]>>,
34+
default(IEnumerable<IGrouping<object, object[]>>));
35+
36+
private static readonly MethodInfo WhereMethod =
37+
ReflectHelper.FastGetMethod(Enumerable.Where<IGrouping<object, object[]>>,
38+
default(IEnumerable<IGrouping<object, object[]>>), default(Func<IGrouping<object, object[]>, bool>));
39+
40+
private readonly ISessionFactory _sessionFactory;
41+
private readonly List<Expression> _expressions;
42+
43+
public NestedSelectClauseRewriter Parent { get; set; }
44+
public ParameterExpression Source { get; set; }
45+
public QueryModel QueryModel { get; set; }
46+
47+
public List<Expression> Expressions
48+
{
49+
get
50+
{
51+
if (Parent == null)
52+
return _expressions;
53+
return Parent.Expressions;
54+
}
55+
}
56+
57+
public int Level
58+
{
59+
get
60+
{
61+
if (Parent != null)
62+
return 1 + Parent.Level;
63+
return 0;
64+
}
65+
}
66+
67+
private static MethodInfo SelectMethod(System.Type type)
68+
{
69+
return ReflectionCache.EnumerableMethods.SelectDefinition
70+
.MakeGenericMethod(typeof(IGrouping<object, object[]>), type);
71+
}
72+
73+
private static MethodInfo ToListMethod(System.Type type)
74+
{
75+
return ReflectionCache.EnumerableMethods.ToListDefinition.MakeGenericMethod(type);
76+
}
77+
78+
private static MethodInfo ToArrayMethod(System.Type type)
79+
{
80+
return ReflectionCache.EnumerableMethods.ToArrayDefinition.MakeGenericMethod(type);
81+
}
82+
83+
private static LambdaExpression NullFilterPredicate()
84+
{
85+
var t = Expression.Parameter(typeof(IGrouping<object, object[]>), "t");
86+
return Expression.Lambda(
87+
Expression.Not(
88+
Expression.Call(typeof(object), "ReferenceEquals", System.Type.EmptyTypes,
89+
GetKeyProperty(t),
90+
Expression.Constant(null))),
91+
t);
92+
}
93+
94+
public NestedSelectClauseRewriter(ISessionFactory sessionFactory, QueryModel queryModel,
95+
NestedSelectClauseRewriter parent = null)
96+
{
97+
_sessionFactory = sessionFactory;
98+
QueryModel = queryModel;
99+
Parent = parent;
100+
101+
if (parent == null)
102+
{
103+
_expressions = new List<Expression>();
104+
Expressions.Add(new QuerySourceReferenceExpression(queryModel.MainFromClause));
105+
}
106+
else
107+
{
108+
var mainFromClause = queryModel.MainFromClause;
109+
var restrictions = queryModel.BodyClauses.OfType<WhereClause>()
110+
.Select(w => new NhWithClause(w.Predicate));
111+
var join = new NhJoinClause(mainFromClause.ItemName, mainFromClause.ItemType,
112+
mainFromClause.FromExpression, restrictions);
113+
var root = GetRoot();
114+
root.BodyClauses.Add(join);
115+
var swapVisitor = new SwapQuerySourceVisitor(mainFromClause, join);
116+
root.TransformExpressions(swapVisitor.Swap);
117+
Expressions.Add(new QuerySourceReferenceExpression(join));
118+
}
119+
120+
Source = Expression.Parameter(typeof(IGrouping<object, object[]>), $"level{Level}");
121+
}
122+
123+
protected QueryModel GetRoot()
124+
{
125+
if (Parent == null)
126+
return QueryModel;
127+
return Parent.GetRoot();
128+
}
129+
130+
public Expression Start()
131+
{
132+
int index = Expressions.Count - 1;
133+
134+
if (Parent == null)
135+
{
136+
var expression = Visit(QueryModel.SelectClause.Selector);
137+
var input = Expression.Parameter(typeof(IEnumerable<object>), "input");
138+
var selector = Expression.Lambda(expression, Source);
139+
var body = Select(selector, GetGroupBy(input, index), expression.Type);
140+
return Expression.Lambda(body, input);
141+
}
142+
143+
var expression2 = Visit(QueryModel.SelectClause.Selector);
144+
return Select(Expression.Lambda(expression2, Source), GetGroupBy(Parent.Source, index), expression2.Type);
145+
}
146+
147+
private MethodCallExpression Select(Expression selector, Expression input, System.Type elementType)
148+
{
149+
return CallAsReadOnly(elementType, CallNullFilteredSelect(selector, input, elementType));
150+
}
151+
152+
private static MethodCallExpression CallNullFilteredSelect(Expression selector, Expression input, System.Type elementType)
153+
{
154+
return Expression.Call(SelectMethod(elementType),
155+
Expression.Call(WhereMethod, input, NullFilterPredicate()),
156+
selector);
157+
}
158+
159+
private Expression SelectCollection(Expression selector, Expression input, System.Type collectionType, System.Type elementType)
160+
{
161+
var filtered = CallNullFilteredSelect(selector, input, elementType);
162+
if (collectionType.IsArray)
163+
return Expression.Call(ToArrayMethod(elementType), filtered);
164+
var constructor = GetCollectionConstructor(collectionType, elementType);
165+
if (constructor != null)
166+
return Expression.New(constructor, filtered);
167+
return CallAsReadOnly(elementType, filtered);
168+
}
169+
170+
private static MethodCallExpression CallAsReadOnly(System.Type elementType, MethodCallExpression select)
171+
{
172+
return Expression.Call(
173+
Expression.Call(ToListMethod(elementType), select),
174+
"AsReadOnly",
175+
System.Type.EmptyTypes);
176+
}
177+
178+
private static MethodCallExpression GetGroupBy(Expression input, int index)
179+
{
180+
return Expression.Call(GroupByMethod,
181+
Expression.Call(CastMethod, input),
182+
GetArrayIndex(index));
183+
}
184+
185+
private static LambdaExpression GetArrayIndex(int index)
186+
{
187+
var g = Expression.Parameter(typeof(object[]), "g");
188+
return Expression.Lambda(Expression.ArrayIndex(g, Expression.Constant(index)), g);
189+
}
190+
191+
protected override Expression VisitSubQuery(SubQueryExpression expression)
192+
{
193+
var child = new NestedSelectClauseRewriter(_sessionFactory, expression.QueryModel, this);
194+
var result = child.Start();
195+
foreach (var expr in child.Expressions)
196+
AddIfMissing(expr);
197+
return result;
198+
}
199+
200+
private static ConstructorInfo GetCollectionConstructor(System.Type collectionType, System.Type elementType)
201+
{
202+
if (collectionType.IsInterface)
203+
{
204+
if (collectionType.IsGenericType && collectionType.GetGenericTypeDefinition() == typeof(ISet<>))
205+
return typeof(HashSet<>).MakeGenericType(elementType)
206+
.GetConstructor(new[] {typeof(IEnumerable<>).MakeGenericType(elementType)});
207+
return null;
208+
}
209+
210+
return collectionType.GetConstructor(new[] {typeof(IEnumerable<>).MakeGenericType(elementType)});
211+
}
212+
213+
protected override Expression VisitMember(MemberExpression expression)
214+
{
215+
if (expression.Expression is QuerySourceReferenceExpression)
216+
{
217+
var memberType = ReflectHelper.GetPropertyOrFieldType(expression.Member);
218+
if (memberType != null && memberType.IsCollectionType()
219+
&& IsMappedCollection(expression.Member) && memberType.IsGenericType)
220+
{
221+
var elementType = memberType.GetGenericArguments()[0];
222+
int index = JoinAndSaveExpression(expression, GetElementType(expression.Type));
223+
var s = Expression.Parameter(typeof(IGrouping<object, object[]>), "s");
224+
var selector = Expression.Lambda(Expression.Convert(GetKeyProperty(s), elementType), s);
225+
return SelectCollection(selector, GetGroupBy(Source, index), memberType, elementType);
226+
}
227+
228+
if (memberType != null && IsMapped(expression.Type))
229+
{
230+
int index = JoinAndSaveExpression(expression, expression.Type);
231+
return Expression.Convert(
232+
GetKeyProperty(Expression.Call(SingleMethod, GetGroupBy(Source, index))),
233+
memberType);
234+
}
235+
}
236+
else if (expression.Expression is MemberExpression parentMember)
237+
{
238+
var memberType = ReflectHelper.GetPropertyOrFieldType(expression.Member);
239+
if (memberType != null && memberType.IsCollectionType() && IsMappedCollection(expression.Member))
240+
{
241+
int index = JoinAndSaveExpression(expression, GetElementType(expression.Type));
242+
var groupByParent = GetGroupByParent(parentMember, Source,
243+
FindParentIndex(index - 1, parentMember.Type));
244+
var elementType = memberType.GetGenericArguments()[0];
245+
var s = Expression.Parameter(typeof(IGrouping<object, object[]>), "s");
246+
var selector = Expression.Lambda(Expression.Convert(GetKeyProperty(s), elementType), s);
247+
return SelectCollection(selector, GetGroupBy(groupByParent, index), memberType, elementType);
248+
}
249+
250+
if (memberType != null && IsMapped(expression.Type))
251+
{
252+
int index = JoinAndSaveExpression(expression, expression.Type);
253+
var groupByParent = GetGroupByParent(parentMember, Source,
254+
FindParentIndex(index - 1, parentMember.Type));
255+
return Expression.Convert(
256+
GetKeyProperty(Expression.Call(SingleMethod, GetGroupBy(groupByParent, index))),
257+
memberType);
258+
}
259+
}
260+
261+
return base.VisitMember(expression);
262+
}
263+
264+
private int FindParentIndex(int startIndex, System.Type parentType)
265+
{
266+
if (Expressions[startIndex].Type == parentType)
267+
return startIndex;
268+
return FindParentIndex(startIndex - 1, parentType);
269+
}
270+
271+
private Expression GetGroupByParent(MemberExpression parent, Expression source, int index)
272+
{
273+
if (parent.Expression is MemberExpression grandParent)
274+
source = GetGroupByParent(grandParent, source, FindParentIndex(index - 1, grandParent.Type));
275+
return Expression.Call(SingleMethod, GetGroupBy(source, index));
276+
}
277+
278+
private int JoinAndSaveExpression(MemberExpression expression, System.Type type)
279+
{
280+
var join = new NhJoinClause(new NameGenerator(QueryModel).GetNewName(), type, expression);
281+
GetRoot().BodyClauses.Add(join);
282+
return AddIfMissing(new QuerySourceReferenceExpression(join));
283+
}
284+
285+
private int AddIfMissing(Expression expression)
286+
{
287+
if (!Expressions.Contains(expression))
288+
Expressions.Add(expression);
289+
return Expressions.IndexOf(expression);
290+
}
291+
292+
protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression)
293+
{
294+
AddIfMissing(expression);
295+
return Expression.Convert(GetKeyProperty(Source), expression.Type);
296+
}
297+
298+
private static Expression GetKeyProperty(Expression source)
299+
{
300+
return Expression.Property(source, "Key");
301+
}
302+
303+
private static System.Type GetElementType(System.Type type)
304+
{
305+
var elementType = ReflectHelper.GetCollectionElementType(type);
306+
if (elementType == null)
307+
throw new NotSupportedException("Unknown collection type " + type.FullName);
308+
return elementType;
309+
}
310+
311+
private bool IsMappedCollection(MemberInfo memberInfo)
312+
{
313+
var roleName = memberInfo.DeclaringType.FullName + "." + memberInfo.Name;
314+
return _sessionFactory.GetCollectionMetadata(roleName) != null;
315+
}
316+
317+
private bool IsMapped(System.Type type)
318+
{
319+
return _sessionFactory.GetClassMetadata(type.FullName) != null;
320+
}
321+
}
322+
}

0 commit comments

Comments
 (0)