Intercepting LINQ Queries
A common request when working with LINQ queries (Entity Framework, NHibernate, etc.) is the ability to intercept them, that is, to inspect an existing query and possibly modify something in it. This is not extremely difficult to do “by hand”, but Microsoft provides a nice class called ExpressionVisitor that makes the job easier. It has virtual methods that get called as it visits each expression contained in a larger expression, which may come from a query (the IQueryable interface exposes the underlying Expression in its Expression property). The virtual methods even allow returning a replacement for each expression found. The only problem is that you must subclass ExpressionVisitor to make even the slightest change, so I wrote my own class. It exposes all node traversal as events, one event for each kind of expression, in which you can return an alternative expression, thus changing the original query. Here is the code for it:
/// <summary>
/// An <see cref="ExpressionVisitor"/> that raises a public event for every kind of node it visits,
/// so callers can inspect or replace parts of an expression tree (for example, the
/// <see cref="IQueryable.Expression"/> of a LINQ query) without writing a subclass.
/// </summary>
/// <remarks>
/// <para>
/// Each handler receives the node that is about to be visited and returns the node to visit in its
/// place. The children of the returned node are then visited as usual, so handlers run in pre-order.
/// The general <c>Expression</c> event is raised for every non-null node passed to
/// <see cref="Visit(System.Linq.Expressions.Expression)"/>. It runs before the node-specific event
/// (<c>Binary</c>, <c>MethodCall</c>, ...) for that node.
/// </para>
/// <para>
/// The events are <see cref="Func{T, TResult}"/> delegates. If several handlers are attached to the
/// same event, all of them are invoked with the same node, but only the value returned by the last
/// one is used. Handlers are expected to return a non-null node.
/// </para>
/// <para>
/// <c>Visit&lt;T, TExpression&gt;</c> and <c>Flatten</c> temporarily attach handlers to this
/// instance. Do not use one instance for concurrent visits.
/// </para>
/// </remarks>
public sealed class ExpressionInterceptor : ExpressionVisitor
{
    #region Public events
    // One event per ExpressionVisitor callback. The handler's return value replaces the visited node.
    public event Func<BinaryExpression, BinaryExpression> Binary;
    public event Func<BlockExpression, BlockExpression> Block;
    public event Func<CatchBlock, CatchBlock> CatchBlock;
    public event Func<ConditionalExpression, ConditionalExpression> Conditional;
    public event Func<ConstantExpression, ConstantExpression> Constant;
    public event Func<DebugInfoExpression, DebugInfoExpression> DebugInfo;
    public event Func<DefaultExpression, DefaultExpression> Default;
    public event Func<DynamicExpression, DynamicExpression> Dynamic;
    public event Func<ElementInit, ElementInit> ElementInit;
    public event Func<Expression, Expression> Expression;
    public event Func<Expression, Expression> Extension;
    public event Func<GotoExpression, GotoExpression> Goto;
    public event Func<IndexExpression, IndexExpression> Index;
    public event Func<InvocationExpression, InvocationExpression> Invocation;
    public event Func<LabelExpression, LabelExpression> Label;
    public event Func<LabelTarget, LabelTarget> LabelTarget;
    public event Func<LambdaExpression, LambdaExpression> Lambda;
    public event Func<ListInitExpression, ListInitExpression> ListInit;
    public event Func<LoopExpression, LoopExpression> Loop;
    public event Func<MemberExpression, MemberExpression> Member;
    public event Func<MemberAssignment, MemberAssignment> MemberAssignment;
    public event Func<MethodCallExpression, MethodCallExpression> MethodCall;
    public event Func<MemberInitExpression, MemberInitExpression> MemberInit;
    public event Func<NewExpression, NewExpression> New;
    public event Func<NewArrayExpression, NewArrayExpression> NewArray;
    public event Func<ParameterExpression, ParameterExpression> Parameter;
    public event Func<RuntimeVariablesExpression, RuntimeVariablesExpression> RuntimeVariables;
    public event Func<SwitchExpression, SwitchExpression> Switch;
    public event Func<TryExpression, TryExpression> Try;
    public event Func<TypeBinaryExpression, TypeBinaryExpression> TypeBinary;
    public event Func<UnaryExpression, UnaryExpression> Unary;
    #endregion

    #region Public methods
    /// <summary>
    /// Visits the expression of <paramref name="query"/> and builds a new query from the result,
    /// using the same provider.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    /// <param name="query">The query to visit.</param>
    /// <returns>A query over the (possibly rewritten) expression.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is null.</exception>
    public IQueryable<T> Visit<T>(IQueryable<T> query)
    {
        if (query == null)
        {
            throw new ArgumentNullException("query");
        }

        // The generic CreateQuery keeps the element type. The previous "as IQueryable<T>" silently
        // produced null when the rewritten expression no longer yielded T.
        return (query.Provider.CreateQuery<T>(this.Visit(query.Expression)));
    }

    /// <summary>
    /// Visits <paramref name="query"/> with <paramref name="action"/> temporarily attached to the
    /// event whose handler type is <c>Func&lt;TExpression, TExpression&gt;</c>.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    /// <typeparam name="TExpression">Node type to intercept. When it is <c>Expression</c>, the general
    /// <c>Expression</c> event is used rather than <c>Extension</c>.</typeparam>
    /// <param name="query">The query to visit.</param>
    /// <param name="action">Handler that returns the replacement for each node of type
    /// <typeparamref name="TExpression"/>.</param>
    /// <returns>A query over the (possibly rewritten) expression.</returns>
    /// <exception cref="InvalidOperationException">No event handles <typeparamref name="TExpression"/>.</exception>
    public IQueryable<T> Visit<T, TExpression>(IQueryable<T> query, Func<TExpression, TExpression> action) where TExpression : Expression
    {
        Type handlerType = typeof(Func<TExpression, TExpression>);

        // The Expression and Extension events share the type Func<Expression, Expression>. Prefer the
        // general Expression event so the choice does not depend on reflection order.
        EventInfo evt = typeof(ExpressionInterceptor)
            .GetEvents(BindingFlags.Public | BindingFlags.Instance)
            .Where(x => x.EventHandlerType == handlerType)
            .OrderBy(x => (x.Name == "Expression") ? 0 : 1)
            .FirstOrDefault();

        if (evt == null)
        {
            throw new InvalidOperationException(String.Format("No event accepts a handler of type {0}.", handlerType));
        }

        evt.AddEventHandler(this, action);

        try
        {
            return (this.Visit(query));
        }
        finally
        {
            // Always detach, even if the visit throws, so the handler cannot leak into later visits.
            evt.RemoveEventHandler(this, action);
        }
    }

    /// <summary>
    /// Visits the expression of <paramref name="query"/> and builds a new (non-generic) query from
    /// the result, using the same provider.
    /// </summary>
    /// <param name="query">The query to visit.</param>
    /// <returns>A query over the (possibly rewritten) expression.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is null.</exception>
    public IQueryable Visit(IQueryable query)
    {
        if (query == null)
        {
            throw new ArgumentNullException("query");
        }

        return (query.Provider.CreateQuery(this.Visit(query.Expression)));
    }

    /// <summary>
    /// Returns every non-null expression node reached while visiting <paramref name="query"/>,
    /// in visit (pre-order) order.
    /// </summary>
    /// <param name="query">The query whose expression tree is walked. It is not executed.</param>
    /// <returns>The collected nodes. A node referenced from several places (for example, a lambda
    /// parameter used in its body) appears once per reference.</returns>
    public IEnumerable<Expression> Flatten(IQueryable query)
    {
        Queue<Expression> nodes = new Queue<Expression>();
        Func<Expression, Expression> collect = delegate(Expression expression)
        {
            if (expression != null)
            {
                nodes.Enqueue(expression);
            }

            // Return the node unchanged: flattening must not rewrite anything.
            return (expression);
        };

        this.Expression += collect;

        try
        {
            this.Visit(query);
        }
        finally
        {
            this.Expression -= collect;
        }

        return (nodes);
    }
    #endregion

    #region Public override methods
    /// <summary>
    /// Offers <paramref name="node"/> to the <c>Expression</c> event, then dispatches the (possibly
    /// replaced) node to the base visitor exactly once.
    /// </summary>
    /// <param name="node">The node to visit. May be null.</param>
    /// <returns>The visited node.</returns>
    public override Expression Visit(Expression node)
    {
        // The previous version visited the subtree both before and after the handler. That re-ran
        // every handler on every node and grew exponentially with tree depth.
        return (base.Visit(Intercept(this.Expression, node)));
    }
    #endregion

    #region Protected override methods
    // Each override passes the node through its matching event (if any handlers are attached), then
    // lets the base ExpressionVisitor traverse the returned node's children.

    protected override Expression VisitNew(NewExpression node)
    {
        return (base.VisitNew(Intercept(this.New, node)));
    }

    protected override Expression VisitNewArray(NewArrayExpression node)
    {
        return (base.VisitNewArray(Intercept(this.NewArray, node)));
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        return (base.VisitParameter(Intercept(this.Parameter, node)));
    }

    protected override Expression VisitRuntimeVariables(RuntimeVariablesExpression node)
    {
        return (base.VisitRuntimeVariables(Intercept(this.RuntimeVariables, node)));
    }

    protected override Expression VisitSwitch(SwitchExpression node)
    {
        return (base.VisitSwitch(Intercept(this.Switch, node)));
    }

    protected override Expression VisitTry(TryExpression node)
    {
        return (base.VisitTry(Intercept(this.Try, node)));
    }

    protected override Expression VisitTypeBinary(TypeBinaryExpression node)
    {
        return (base.VisitTypeBinary(Intercept(this.TypeBinary, node)));
    }

    protected override Expression VisitUnary(UnaryExpression node)
    {
        return (base.VisitUnary(Intercept(this.Unary, node)));
    }

    protected override Expression VisitMemberInit(MemberInitExpression node)
    {
        return (base.VisitMemberInit(Intercept(this.MemberInit, node)));
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        return (base.VisitMethodCall(Intercept(this.MethodCall, node)));
    }

    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        if ((this.Lambda == null) || (node == null))
        {
            return (base.VisitLambda<T>(node));
        }

        // base.VisitLambda<T> needs an Expression<T>. A replacement of another delegate type (or null)
        // previously became null through "as" and caused a NullReferenceException inside the base
        // visitor. Report it clearly instead.
        Expression<T> replacement = this.Lambda(node) as Expression<T>;

        if (replacement == null)
        {
            throw new InvalidOperationException(String.Format("The Lambda handler must return an Expression<{0}>.", typeof(T)));
        }

        return (base.VisitLambda<T>(replacement));
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        return (base.VisitBinary(Intercept(this.Binary, node)));
    }

    protected override Expression VisitBlock(BlockExpression node)
    {
        return (base.VisitBlock(Intercept(this.Block, node)));
    }

    protected override CatchBlock VisitCatchBlock(CatchBlock node)
    {
        return (base.VisitCatchBlock(Intercept(this.CatchBlock, node)));
    }

    protected override Expression VisitConditional(ConditionalExpression node)
    {
        return (base.VisitConditional(Intercept(this.Conditional, node)));
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        return (base.VisitConstant(Intercept(this.Constant, node)));
    }

    protected override Expression VisitDebugInfo(DebugInfoExpression node)
    {
        return (base.VisitDebugInfo(Intercept(this.DebugInfo, node)));
    }

    protected override Expression VisitDefault(DefaultExpression node)
    {
        return (base.VisitDefault(Intercept(this.Default, node)));
    }

    protected override Expression VisitDynamic(DynamicExpression node)
    {
        return (base.VisitDynamic(Intercept(this.Dynamic, node)));
    }

    protected override ElementInit VisitElementInit(ElementInit node)
    {
        return (base.VisitElementInit(Intercept(this.ElementInit, node)));
    }

    protected override Expression VisitExtension(Expression node)
    {
        return (base.VisitExtension(Intercept(this.Extension, node)));
    }

    protected override Expression VisitGoto(GotoExpression node)
    {
        return (base.VisitGoto(Intercept(this.Goto, node)));
    }

    protected override Expression VisitIndex(IndexExpression node)
    {
        return (base.VisitIndex(Intercept(this.Index, node)));
    }

    protected override Expression VisitInvocation(InvocationExpression node)
    {
        return (base.VisitInvocation(Intercept(this.Invocation, node)));
    }

    protected override Expression VisitLabel(LabelExpression node)
    {
        return (base.VisitLabel(Intercept(this.Label, node)));
    }

    protected override LabelTarget VisitLabelTarget(LabelTarget node)
    {
        return (base.VisitLabelTarget(Intercept(this.LabelTarget, node)));
    }

    protected override Expression VisitListInit(ListInitExpression node)
    {
        return (base.VisitListInit(Intercept(this.ListInit, node)));
    }

    protected override Expression VisitLoop(LoopExpression node)
    {
        return (base.VisitLoop(Intercept(this.Loop, node)));
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        return (base.VisitMember(Intercept(this.Member, node)));
    }

    protected override MemberAssignment VisitMemberAssignment(MemberAssignment node)
    {
        return (base.VisitMemberAssignment(Intercept(this.MemberAssignment, node)));
    }
    #endregion

    #region Private methods
    /// <summary>
    /// Returns <c>handler(node)</c> when both the handler and the node are non-null; otherwise returns
    /// <paramref name="node"/> unchanged.
    /// </summary>
    private static TNode Intercept<TNode>(Func<TNode, TNode> handler, TNode node) where TNode : class
    {
        return (((handler != null) && (node != null)) ? handler(node) : node);
    }
    #endregion
}
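Yes, I know: since every handler returns a value, I probably should have used properties (single delegates) instead of events. With an event, if several handlers are attached, all of them run, but only the last one’s return value is used. That’s really not important here, though.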
A simple example might be:
ExpressionInterceptor interceptor = new ExpressionInterceptor();
String[] lettersArray = new String[] { "A", "B", "C" }; //a data source
IQueryable<String> lettersQuery = lettersArray.AsQueryable().Where(x => x == "A").OrderByDescending(x => x).Select(x => x.ToUpper()); //a silly query
IQueryable<String> lettersInterceptedQuery = interceptor.Visit<String, MethodCallExpression>(lettersQuery, x =>
{
    if (x.Method.Name == "ToUpper")
    {
        //change from uppercase to lowercase; ask for the parameterless overload explicitly, because
        //String also has ToLower(CultureInfo) and GetMethods() does not guarantee which one comes first
        x = Expression.Call(x.Object, typeof(String).GetMethod("ToLower", Type.EmptyTypes));
    }

    return (x);
});
lettersInterceptedQuery = interceptor.Visit<String, BinaryExpression>(lettersInterceptedQuery, x =>
{
    //change every binary expression (here, only the equality test) from equal to not equal
    x = Expression.MakeBinary(ExpressionType.NotEqual, x.Left, x.Right);

    return (x);
});
IEnumerable<Expression> lettersExpressions = interceptor.Flatten(lettersQuery); //all expressions found
IEnumerable<String> lettersList = lettersQuery.ToList(); //"A"
IEnumerable<String> lettersInterceptedList = lettersInterceptedQuery.ToList(); //"c", "b"
As you can see, there are methods that visit an IQueryable, an IQueryable<T>, or an Expression. There is even an inline version that takes a Func<TExpression, TExpression> for even easier usage.
As always, I hope you find it useful!