Intercepting LINQ Queries

A common request when working with LINQ queries (Entity Framework, NHibernate, etc.) is the ability to intercept them, that is, to inspect an existing query and possibly modify something in it. This is not extremely difficult to do “by hand”, but Microsoft has a nice class called ExpressionVisitor which makes the job easier. It has virtual methods that get called as the class visits each expression contained in a larger expression, which may come from a query (the IQueryable interface exposes the underlying expression tree in its Expression property). The virtual methods even allow returning a replacement for each expression found.

The only problem is that you must subclass ExpressionVisitor to make even the slightest change, so I wrote my own class, which exposes node traversal as events, one event for each kind of expression. In each handler you can return an alternative expression, thus changing the original query. Here is the code for it:

    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Reflection;

    public sealed class ExpressionInterceptor : ExpressionVisitor
    {
        #region Public events
        public event Func<BinaryExpression, BinaryExpression> Binary;
        public event Func<BlockExpression, BlockExpression> Block;
        public event Func<CatchBlock, CatchBlock> CatchBlock;
        public event Func<ConditionalExpression, ConditionalExpression> Conditional;
        public event Func<ConstantExpression, ConstantExpression> Constant;
        public event Func<DebugInfoExpression, DebugInfoExpression> DebugInfo;
        public event Func<DefaultExpression, DefaultExpression> Default;
        public event Func<DynamicExpression, DynamicExpression> Dynamic;
        public event Func<ElementInit, ElementInit> ElementInit;
        public event Func<Expression, Expression> Expression;
        public event Func<Expression, Expression> Extension;
        public event Func<GotoExpression, GotoExpression> Goto;
        public event Func<IndexExpression, IndexExpression> Index;
        public event Func<InvocationExpression, InvocationExpression> Invocation;
        public event Func<LabelExpression, LabelExpression> Label;
        public event Func<LabelTarget, LabelTarget> LabelTarget;
        public event Func<LambdaExpression, LambdaExpression> Lambda;
        public event Func<ListInitExpression, ListInitExpression> ListInit;
        public event Func<LoopExpression, LoopExpression> Loop;
        public event Func<MemberExpression, MemberExpression> Member;
        public event Func<MemberAssignment, MemberAssignment> MemberAssignment;
        public event Func<MethodCallExpression, MethodCallExpression> MethodCall;
        public event Func<MemberInitExpression, MemberInitExpression> MemberInit;
        public event Func<NewExpression, NewExpression> New;
        public event Func<NewArrayExpression, NewArrayExpression> NewArray;
        public event Func<ParameterExpression, ParameterExpression> Parameter;
        public event Func<RuntimeVariablesExpression, RuntimeVariablesExpression> RuntimeVariables;
        public event Func<SwitchExpression, SwitchExpression> Switch;
        public event Func<TryExpression, TryExpression> Try;
        public event Func<TypeBinaryExpression, TypeBinaryExpression> TypeBinary;
        public event Func<UnaryExpression, UnaryExpression> Unary;
        #endregion

        #region Public methods
        public IQueryable<T> Visit<T>(IQueryable<T> query)
        {
            return (this.Visit(query as IQueryable) as IQueryable<T>);
        }

        public IQueryable<T> Visit<T, TExpression>(IQueryable<T> query, Func<TExpression, TExpression> action) where TExpression : Expression
        {
            //find the event whose handler type matches the supplied delegate
            EventInfo evt = this.GetType().GetEvents(BindingFlags.Public | BindingFlags.Instance).Where(x => x.EventHandlerType == typeof(Func<TExpression, TExpression>)).First();
            evt.AddEventHandler(this, action);

            try
            {
                return (this.Visit(query));
            }
            finally
            {
                //always detach the handler, even if the traversal throws
                evt.RemoveEventHandler(this, action);
            }
        }

        public IQueryable Visit(IQueryable query)
        {
            //rebuild the query from the (possibly modified) expression tree
            return (query.Provider.CreateQuery(this.Visit(query.Expression)));
        }

        public IEnumerable<Expression> Flatten(IQueryable query)
        {
            Queue<Expression> list = new Queue<Expression>();
            Func<Expression, Expression> action = delegate(Expression expression)
            {
                if (expression != null)
                {
                    list.Enqueue(expression);
                }

                return (expression);
            };

            this.Expression += action;

            try
            {
                this.Visit(query);
            }
            finally
            {
                this.Expression -= action;
            }

            return (list);
        }
        #endregion

        #region Public override methods
        public override Expression Visit(Expression node)
        {
            if ((this.Expression != null) && (node != null))
            {
                //call the handler once per node, then traverse whatever it returned;
                //visiting the node both before and after the handler would walk every subtree twice
                return (base.Visit(this.Expression(node)));
            }
            else
            {
                return (base.Visit(node));
            }
        }
        #endregion

        #region Protected override methods
        protected override Expression VisitNew(NewExpression node)
        {
            if ((this.New != null) && (node != null))
            {
                return (base.VisitNew(this.New(node)));
            }
            else
            {
                return (base.VisitNew(node));
            }
        }

        protected override Expression VisitNewArray(NewArrayExpression node)
        {
            if ((this.NewArray != null) && (node != null))
            {
                return (base.VisitNewArray(this.NewArray(node)));
            }
            else
            {
                return (base.VisitNewArray(node));
            }
        }

        protected override Expression VisitParameter(ParameterExpression node)
        {
            if ((this.Parameter != null) && (node != null))
            {
                return (base.VisitParameter(this.Parameter(node)));
            }
            else
            {
                return (base.VisitParameter(node));
            }
        }

        protected override Expression VisitRuntimeVariables(RuntimeVariablesExpression node)
        {
            if ((this.RuntimeVariables != null) && (node != null))
            {
                return (base.VisitRuntimeVariables(this.RuntimeVariables(node)));
            }
            else
            {
                return (base.VisitRuntimeVariables(node));
            }
        }

        protected override Expression VisitSwitch(SwitchExpression node)
        {
            if ((this.Switch != null) && (node != null))
            {
                return (base.VisitSwitch(this.Switch(node)));
            }
            else
            {
                return (base.VisitSwitch(node));
            }
        }

        protected override Expression VisitTry(TryExpression node)
        {
            if ((this.Try != null) && (node != null))
            {
                return (base.VisitTry(this.Try(node)));
            }
            else
            {
                return (base.VisitTry(node));
            }
        }

        protected override Expression VisitTypeBinary(TypeBinaryExpression node)
        {
            if ((this.TypeBinary != null) && (node != null))
            {
                return (base.VisitTypeBinary(this.TypeBinary(node)));
            }
            else
            {
                return (base.VisitTypeBinary(node));
            }
        }

        protected override Expression VisitUnary(UnaryExpression node)
        {
            if ((this.Unary != null) && (node != null))
            {
                return (base.VisitUnary(this.Unary(node)));
            }
            else
            {
                return (base.VisitUnary(node));
            }
        }

        protected override Expression VisitMemberInit(MemberInitExpression node)
        {
            if ((this.MemberInit != null) && (node != null))
            {
                return (base.VisitMemberInit(this.MemberInit(node)));
            }
            else
            {
                return (base.VisitMemberInit(node));
            }
        }

        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if ((this.MethodCall != null) && (node != null))
            {
                return (base.VisitMethodCall(this.MethodCall(node)));
            }
            else
            {
                return (base.VisitMethodCall(node));
            }
        }

        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            if ((this.Lambda != null) && (node != null))
            {
                return (base.VisitLambda<T>(this.Lambda(node) as Expression<T>));
            }
            else
            {
                return (base.VisitLambda<T>(node));
            }
        }

        protected override Expression VisitBinary(BinaryExpression node)
        {
            if ((this.Binary != null) && (node != null))
            {
                return (base.VisitBinary(this.Binary(node)));
            }
            else
            {
                return (base.VisitBinary(node));
            }
        }

        protected override Expression VisitBlock(BlockExpression node)
        {
            if ((this.Block != null) && (node != null))
            {
                return (base.VisitBlock(this.Block(node)));
            }
            else
            {
                return (base.VisitBlock(node));
            }
        }

        protected override CatchBlock VisitCatchBlock(CatchBlock node)
        {
            if ((this.CatchBlock != null) && (node != null))
            {
                return (base.VisitCatchBlock(this.CatchBlock(node)));
            }
            else
            {
                return (base.VisitCatchBlock(node));
            }
        }

        protected override Expression VisitConditional(ConditionalExpression node)
        {
            if ((this.Conditional != null) && (node != null))
            {
                return (base.VisitConditional(this.Conditional(node)));
            }
            else
            {
                return (base.VisitConditional(node));
            }
        }

        protected override Expression VisitConstant(ConstantExpression node)
        {
            if ((this.Constant != null) && (node != null))
            {
                return (base.VisitConstant(this.Constant(node)));
            }
            else
            {
                return (base.VisitConstant(node));
            }
        }

        protected override Expression VisitDebugInfo(DebugInfoExpression node)
        {
            if ((this.DebugInfo != null) && (node != null))
            {
                return (base.VisitDebugInfo(this.DebugInfo(node)));
            }
            else
            {
                return (base.VisitDebugInfo(node));
            }
        }

        protected override Expression VisitDefault(DefaultExpression node)
        {
            if ((this.Default != null) && (node != null))
            {
                return (base.VisitDefault(this.Default(node)));
            }
            else
            {
                return (base.VisitDefault(node));
            }
        }

        protected override Expression VisitDynamic(DynamicExpression node)
        {
            if ((this.Dynamic != null) && (node != null))
            {
                return (base.VisitDynamic(this.Dynamic(node)));
            }
            else
            {
                return (base.VisitDynamic(node));
            }
        }

        protected override ElementInit VisitElementInit(ElementInit node)
        {
            if ((this.ElementInit != null) && (node != null))
            {
                return (base.VisitElementInit(this.ElementInit(node)));
            }
            else
            {
                return (base.VisitElementInit(node));
            }
        }

        protected override Expression VisitExtension(Expression node)
        {
            if ((this.Extension != null) && (node != null))
            {
                return (base.VisitExtension(this.Extension(node)));
            }
            else
            {
                return (base.VisitExtension(node));
            }
        }

        protected override Expression VisitGoto(GotoExpression node)
        {
            if ((this.Goto != null) && (node != null))
            {
                return (base.VisitGoto(this.Goto(node)));
            }
            else
            {
                return (base.VisitGoto(node));
            }
        }

        protected override Expression VisitIndex(IndexExpression node)
        {
            if ((this.Index != null) && (node != null))
            {
                return (base.VisitIndex(this.Index(node)));
            }
            else
            {
                return (base.VisitIndex(node));
            }
        }

        protected override Expression VisitInvocation(InvocationExpression node)
        {
            if ((this.Invocation != null) && (node != null))
            {
                return (base.VisitInvocation(this.Invocation(node)));
            }
            else
            {
                return (base.VisitInvocation(node));
            }
        }

        protected override Expression VisitLabel(LabelExpression node)
        {
            if ((this.Label != null) && (node != null))
            {
                return (base.VisitLabel(this.Label(node)));
            }
            else
            {
                return (base.VisitLabel(node));
            }
        }

        protected override LabelTarget VisitLabelTarget(LabelTarget node)
        {
            if ((this.LabelTarget != null) && (node != null))
            {
                return (base.VisitLabelTarget(this.LabelTarget(node)));
            }
            else
            {
                return (base.VisitLabelTarget(node));
            }
        }

        protected override Expression VisitListInit(ListInitExpression node)
        {
            if ((this.ListInit != null) && (node != null))
            {
                return (base.VisitListInit(this.ListInit(node)));
            }
            else
            {
                return (base.VisitListInit(node));
            }
        }

        protected override Expression VisitLoop(LoopExpression node)
        {
            if ((this.Loop != null) && (node != null))
            {
                return (base.VisitLoop(this.Loop(node)));
            }
            else
            {
                return (base.VisitLoop(node));
            }
        }

        protected override Expression VisitMember(MemberExpression node)
        {
            if ((this.Member != null) && (node != null))
            {
                return (base.VisitMember(this.Member(node)));
            }
            else
            {
                return (base.VisitMember(node));
            }
        }

        protected override MemberAssignment VisitMemberAssignment(MemberAssignment node)
        {
            if ((this.MemberAssignment != null) && (node != null))
            {
                return (base.VisitMemberAssignment(this.MemberAssignment(node)));
            }
            else
            {
                return (base.VisitMemberAssignment(node));
            }
        }
        #endregion
    }

Yes, I know, I probably should have used properties instead of events, but that’s really not important.
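One practical consequence of the event-based design is worth seeing in isolation: if more than one handler is attached to the same event, every handler runs with the same input, but only the value returned by the last one is used. The sketch below shows this with a plain Func<Int32, Int32> rather than with the interceptor itself; MulticastDemo is a made-up name for illustration only.

```csharp
using System;

public static class MulticastDemo
{
    public static Int32 Run()
    {
        Func<Int32, Int32> handlers = x => x + 1;    //first subscriber
        handlers += x => x * 10;                     //second subscriber

        //both subscribers are invoked with 5, but only the last result comes back
        return (handlers(5));
    }

    public static void Main()
    {
        Console.WriteLine(MulticastDemo.Run());    //prints 50
    }
}
```

So with the interceptor, attaching two handlers to, say, MethodCall does not chain their replacements; a property holding a single delegate would make that limitation explicit.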

A simple example might be:

    ExpressionInterceptor interceptor = new ExpressionInterceptor();
    String[] lettersArray = new String[] { "A", "B", "C" };    //a data source
    IQueryable<String> lettersQuery = lettersArray.AsQueryable().Where(x => x == "A").OrderByDescending(x => x).Select(x => x.ToUpper());    //a silly query
    IQueryable<String> lettersInterceptedQuery = interceptor.Visit<String, MethodCallExpression>(lettersQuery, x =>
    {
        if (x.Method.Name == "ToUpper")
        {
            //change from uppercase to lowercase; ask for the parameterless overload explicitly,
            //since String also has ToLower(CultureInfo)
            x = Expression.Call(x.Object, typeof(String).GetMethod("ToLower", Type.EmptyTypes));
        }

        return (x);
    });
    lettersInterceptedQuery = interceptor.Visit<String, BinaryExpression>(lettersInterceptedQuery, x =>
    {
        if (x.NodeType == ExpressionType.Equal)
        {
            //change from equal to not equal
            x = Expression.MakeBinary(ExpressionType.NotEqual, x.Left, x.Right);
        }

        return (x);
    });
    IEnumerable<Expression> lettersExpressions = interceptor.Flatten(lettersQuery);    //all expressions found
    IEnumerable<String> lettersList = lettersQuery.ToList();    //"A"
    IEnumerable<String> lettersInterceptedList = lettersInterceptedQuery.ToList();    //"c", "b"

You see, I have methods that visit an IQueryable, an IQueryable<T> or an Expression, and there is even an inline version that takes a Func<TExpression, TExpression> for even easier usage.
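For comparison, here is roughly what the ToUpper to ToLower swap looks like when you subclass ExpressionVisitor directly instead of going through events. This is only a sketch: ToLowerVisitor and SubclassDemo are names I made up for the illustration, and it uses a small in-memory array as the data source.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

//hypothetical visitor: swaps String.ToUpper() calls for String.ToLower()
public sealed class ToLowerVisitor : ExpressionVisitor
{
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if ((node.Method.DeclaringType == typeof(String)) && (node.Method.Name == "ToUpper") && (node.Arguments.Count == 0))
        {
            //visit the target first so that nested calls are rewritten too
            return (Expression.Call(this.Visit(node.Object), typeof(String).GetMethod("ToLower", Type.EmptyTypes)));
        }

        return (base.VisitMethodCall(node));
    }
}

public static class SubclassDemo
{
    public static String[] Run()
    {
        IQueryable<String> query = new String[] { "A", "B" }.AsQueryable().Select(x => x.ToUpper());
        Expression rewritten = new ToLowerVisitor().Visit(query.Expression);

        //build a new query over the rewritten expression tree and run it
        return (query.Provider.CreateQuery<String>(rewritten).ToArray());
    }

    public static void Main()
    {
        Console.WriteLine(String.Join(",", SubclassDemo.Run()));    //prints a,b
    }
}
```

It works fine, but every new kind of change means another subclass or another override, which is what the event-based interceptor tries to avoid.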

As always, hope you find it useful!

                             
