Caching LINQ Queries

Introduction

Some days ago I wrote a post on comparing LINQ expressions, where I shared my discoveries and my general discontent with how difficult it is to do right. I was looking into it because I wanted to write a LINQ query caching mechanism.

There are actually three main problems involved:

  • Comparing two LINQ expressions;
  • Caching a LINQ expression for some time;
  • For cached instances, preventing them from executing when they are enumerated after the first time.

Comparing LINQ Expressions

I ended up writing my own comparer, which involved looking at all possible types of Expression (around 20 classes). Some are easy because they have few relevant properties: for ConstantExpression, for instance, I only had to consider the Value property. Others have several: BinaryExpression has Left, Right, Conversion and Method. OK, so here is the code:

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq.Expressions;
using System.Reflection;

//structural comparer for LINQ expression trees; not thread-safe (see the hashCode field)
public sealed class ExpressionEqualityComparer : IEqualityComparer<Expression>
{
    #region Private fields
    //accumulator for the hash currently being computed
    private Int32 hashCode;
    #endregion

    #region Hash code
    private void Combine(Int32 value)
    {
        //order-sensitive mix: unlike a plain XOR, identical sub-hashes do not cancel
        //each other out, and swapped operands usually produce different hashes
        unchecked
        {
            this.hashCode = (this.hashCode * 397) ^ value;
        }
    }

    private void Visit(Expression expression)
    {
        if (expression == null)
        {
            return;
        }

        this.Combine((Int32) expression.NodeType);
        this.Combine(expression.Type.GetHashCode());

        switch (expression.NodeType)
        {
            case ExpressionType.ArrayLength:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.Negate:
            case ExpressionType.UnaryPlus:
            case ExpressionType.NegateChecked:
            case ExpressionType.Not:
            case ExpressionType.Quote:
            case ExpressionType.TypeAs:
                this.VisitUnary((UnaryExpression) expression);
                break;

            case ExpressionType.Add:
            case ExpressionType.AddChecked:
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.ArrayIndex:
            case ExpressionType.Coalesce:
            case ExpressionType.Divide:
            case ExpressionType.Equal:
            case ExpressionType.ExclusiveOr:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.LeftShift:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.Modulo:
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.NotEqual:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.Power:
            case ExpressionType.RightShift:
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
                this.VisitBinary((BinaryExpression) expression);
                break;

            case ExpressionType.Call:
                this.VisitMethodCall((MethodCallExpression) expression);
                break;

            case ExpressionType.Conditional:
                this.VisitConditional((ConditionalExpression) expression);
                break;

            case ExpressionType.Constant:
                this.VisitConstant((ConstantExpression) expression);
                break;

            case ExpressionType.Invoke:
                this.VisitInvocation((InvocationExpression) expression);
                break;

            case ExpressionType.Lambda:
                this.VisitLambda((LambdaExpression) expression);
                break;

            case ExpressionType.ListInit:
                this.VisitListInit((ListInitExpression) expression);
                break;

            case ExpressionType.MemberAccess:
                this.VisitMemberAccess((MemberExpression) expression);
                break;

            case ExpressionType.MemberInit:
                this.VisitMemberInit((MemberInitExpression) expression);
                break;

            case ExpressionType.New:
                this.VisitNew((NewExpression) expression);
                break;

            case ExpressionType.NewArrayInit:
            case ExpressionType.NewArrayBounds:
                this.VisitNewArray((NewArrayExpression) expression);
                break;

            case ExpressionType.Parameter:
                this.VisitParameter((ParameterExpression) expression);
                break;

            case ExpressionType.TypeIs:
                this.VisitTypeIs((TypeBinaryExpression) expression);
                break;

            default:
                throw (new ArgumentException("Unhandled expression type"));
        }
    }

    private void VisitUnary(UnaryExpression expression)
    {
        if (expression.Method != null)
        {
            this.Combine(expression.Method.GetHashCode());
        }

        this.Visit(expression.Operand);
    }

    private void VisitBinary(BinaryExpression expression)
    {
        if (expression.Method != null)
        {
            this.Combine(expression.Method.GetHashCode());
        }

        this.Visit(expression.Left);
        this.Visit(expression.Right);
        this.Visit(expression.Conversion);
    }

    private void VisitMethodCall(MethodCallExpression expression)
    {
        this.Combine(expression.Method.GetHashCode());

        this.Visit(expression.Object);
        this.VisitExpressionList(expression.Arguments);
    }

    private void VisitConditional(ConditionalExpression expression)
    {
        this.Visit(expression.Test);
        this.Visit(expression.IfTrue);
        this.Visit(expression.IfFalse);
    }

    private void VisitConstant(ConstantExpression expression)
    {
        if (expression.Value != null)
        {
            this.Combine(expression.Value.GetHashCode());
        }
    }

    private void VisitInvocation(InvocationExpression expression)
    {
        this.Visit(expression.Expression);
        this.VisitExpressionList(expression.Arguments);
    }

    private void VisitLambda(LambdaExpression expression)
    {
        //the lambda's Name is not compared by Equals, so it must not affect the hash either
        this.Visit(expression.Body);
        this.VisitParameterList(expression.Parameters);
    }

    private void VisitListInit(ListInitExpression expression)
    {
        this.VisitNew(expression.NewExpression);
        this.VisitElementInitializerList(expression.Initializers);
    }

    private void VisitMemberAccess(MemberExpression expression)
    {
        this.Combine(expression.Member.GetHashCode());
        this.Visit(expression.Expression);
    }

    private void VisitMemberInit(MemberInitExpression expression)
    {
        this.Visit(expression.NewExpression);
        this.VisitBindingList(expression.Bindings);
    }

    private void VisitNew(NewExpression expression)
    {
        //Constructor is null for "new S()" when S is a value type
        if (expression.Constructor != null)
        {
            this.Combine(expression.Constructor.GetHashCode());
        }

        this.VisitMemberList(expression.Members);
        this.VisitExpressionList(expression.Arguments);
    }

    private void VisitNewArray(NewArrayExpression expression)
    {
        this.VisitExpressionList(expression.Expressions);
    }

    private void VisitParameter(ParameterExpression expression)
    {
        //names are ignored, matching Equals, so that x => ... and y => ... hash alike
        this.Combine(expression.IsByRef ? 1 : 0);
    }

    private void VisitTypeIs(TypeBinaryExpression expression)
    {
        this.Combine(expression.TypeOperand.GetHashCode());
        this.Visit(expression.Expression);
    }

    private void VisitBinding(MemberBinding binding)
    {
        this.Combine((Int32) binding.BindingType);
        this.Combine(binding.Member.GetHashCode());

        switch (binding.BindingType)
        {
            case MemberBindingType.Assignment:
                this.VisitMemberAssignment((MemberAssignment)binding);
                break;

            case MemberBindingType.MemberBinding:
                this.VisitMemberMemberBinding((MemberMemberBinding)binding);
                break;

            case MemberBindingType.ListBinding:
                this.VisitMemberListBinding((MemberListBinding)binding);
                break;

            default:
                throw (new ArgumentException("Unhandled binding type"));
        }
    }

    private void VisitMemberAssignment(MemberAssignment assignment)
    {
        this.Visit(assignment.Expression);
    }

    private void VisitMemberMemberBinding(MemberMemberBinding binding)
    {
        this.VisitBindingList(binding.Bindings);
    }

    private void VisitMemberListBinding(MemberListBinding binding)
    {
        this.VisitElementInitializerList(binding.Initializers);
    }

    private void VisitElementInitializer(ElementInit initializer)
    {
        this.Combine(initializer.AddMethod.GetHashCode());

        this.VisitExpressionList(initializer.Arguments);
    }

    private void VisitExpressionList(ReadOnlyCollection<Expression> list)
    {
        if (list != null)
        {
            for (Int32 i = 0; i < list.Count; i++)
            {
                this.Visit(list[i]);
            }
        }
    }

    private void VisitParameterList(ReadOnlyCollection<ParameterExpression> list)
    {
        if (list != null)
        {
            for (Int32 i = 0; i < list.Count; i++)
            {
                this.Visit(list[i]);
            }
        }
    }

    private void VisitBindingList(ReadOnlyCollection<MemberBinding> list)
    {
        if (list != null)
        {
            for (Int32 i = 0; i < list.Count; i++)
            {
                this.VisitBinding(list[i]);
            }
        }
    }

    private void VisitElementInitializerList(ReadOnlyCollection<ElementInit> list)
    {
        if (list != null)
        {
            for (Int32 i = 0; i < list.Count; i++)
            {
                this.VisitElementInitializer(list[i]);
            }
        }
    }

    private void VisitMemberList(ReadOnlyCollection<MemberInfo> list)
    {
        if (list != null)
        {
            for (Int32 i = 0; i < list.Count; i++)
            {
                this.Combine(list[i].GetHashCode());
            }
        }
    }
    #endregion

    #region Equality
    private Boolean Visit(Expression x, Expression y)
    {
        if (Object.ReferenceEquals(x, y) == true)
        {
            return (true);
        }

        if ((x == null) || (y == null))
        {
            return (false);
        }

        if ((x.NodeType != y.NodeType) || (x.Type != y.Type))
        {
            return (false);
        }

        switch (x.NodeType)
        {
            case ExpressionType.ArrayLength:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.Negate:
            case ExpressionType.UnaryPlus:
            case ExpressionType.NegateChecked:
            case ExpressionType.Not:
            case ExpressionType.Quote:
            case ExpressionType.TypeAs:
                return (this.VisitUnary((UnaryExpression)x, (UnaryExpression)y));

            case ExpressionType.Add:
            case ExpressionType.AddChecked:
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.ArrayIndex:
            case ExpressionType.Coalesce:
            case ExpressionType.Divide:
            case ExpressionType.Equal:
            case ExpressionType.ExclusiveOr:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.LeftShift:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.Modulo:
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.NotEqual:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.Power:
            case ExpressionType.RightShift:
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
                return (this.VisitBinary((BinaryExpression)x, (BinaryExpression)y));

            case ExpressionType.Call:
                return (this.VisitMethodCall((MethodCallExpression)x, (MethodCallExpression)y));

            case ExpressionType.Conditional:
                return (this.VisitConditional((ConditionalExpression)x, (ConditionalExpression)y));

            case ExpressionType.Constant:
                return (this.VisitConstant((ConstantExpression)x, (ConstantExpression)y));

            case ExpressionType.Invoke:
                return (this.VisitInvocation((InvocationExpression)x, (InvocationExpression)y));

            case ExpressionType.Lambda:
                return (this.VisitLambda((LambdaExpression)x, (LambdaExpression)y));

            case ExpressionType.ListInit:
                return (this.VisitListInit((ListInitExpression)x, (ListInitExpression)y));

            case ExpressionType.MemberAccess:
                return (this.VisitMemberAccess((MemberExpression)x, (MemberExpression)y));

            case ExpressionType.MemberInit:
                return (this.VisitMemberInit((MemberInitExpression)x, (MemberInitExpression)y));

            case ExpressionType.New:
                return (this.VisitNew((NewExpression)x, (NewExpression)y));

            case ExpressionType.NewArrayInit:
            case ExpressionType.NewArrayBounds:
                return (this.VisitNewArray((NewArrayExpression)x, (NewArrayExpression)y));

            case ExpressionType.Parameter:
                return (this.VisitParameter((ParameterExpression)x, (ParameterExpression)y));

            case ExpressionType.TypeIs:
                return (this.VisitTypeIs((TypeBinaryExpression)x, (TypeBinaryExpression)y));

            default:
                throw (new ArgumentException("Unhandled expression type"));
        }
    }

    private Boolean VisitUnary(UnaryExpression x, UnaryExpression y)
    {
        return ((x.Method == y.Method) &&
               (this.Visit(x.Operand, y.Operand)));
    }

    private Boolean VisitBinary(BinaryExpression x, BinaryExpression y)
    {
        return ((x.Method == y.Method) &&
               (this.Visit(x.Left, y.Left)) &&
               (this.Visit(x.Right, y.Right)) &&
               (this.Visit(x.Conversion, y.Conversion)));
    }

    private Boolean VisitMethodCall(MethodCallExpression x, MethodCallExpression y)
    {
        return ((x.Method == y.Method) &&
               (this.Visit(x.Object, y.Object)) &&
               (this.VisitExpressionList(x.Arguments, y.Arguments)));
    }

    private Boolean VisitConditional(ConditionalExpression x, ConditionalExpression y)
    {
        return ((this.Visit(x.Test, y.Test)) &&
               (this.Visit(x.IfTrue, y.IfTrue)) &&
               (this.Visit(x.IfFalse, y.IfFalse)));
    }

    private Boolean VisitConstant(ConstantExpression x, ConstantExpression y)
    {
        return (Object.Equals(x.Value, y.Value));
    }

    private Boolean VisitInvocation(InvocationExpression x, InvocationExpression y)
    {
        //compare the arguments of both invocations (not x against itself)
        return ((this.Visit(x.Expression, y.Expression)) &&
               (this.VisitExpressionList(x.Arguments, y.Arguments)));
    }

    private Boolean VisitLambda(LambdaExpression x, LambdaExpression y)
    {
        return ((this.Visit(x.Body, y.Body)) &&
               (this.VisitParameterList(x.Parameters, y.Parameters)));
    }

    private Boolean VisitListInit(ListInitExpression x, ListInitExpression y)
    {
        return ((this.VisitNew(x.NewExpression, y.NewExpression)) &&
               (this.VisitElementInitializerList(x.Initializers, y.Initializers)));
    }

    private Boolean VisitMemberAccess(MemberExpression x, MemberExpression y)
    {
        return ((x.Member == y.Member) &&
               (this.Visit(x.Expression, y.Expression)));
    }

    private Boolean VisitMemberInit(MemberInitExpression x, MemberInitExpression y)
    {
        return ((this.Visit(x.NewExpression, y.NewExpression)) &&
               (this.VisitBindingList(x.Bindings, y.Bindings)));
    }

    private Boolean VisitNew(NewExpression x, NewExpression y)
    {
        return ((x.Constructor == y.Constructor) &&
               (this.VisitMemberList(x.Members, y.Members)) &&
               (this.VisitExpressionList(x.Arguments, y.Arguments)));
    }

    private Boolean VisitNewArray(NewArrayExpression x, NewArrayExpression y)
    {
        return (this.VisitExpressionList(x.Expressions, y.Expressions));
    }

    private Boolean VisitParameter(ParameterExpression x, ParameterExpression y)
    {
        //names are deliberately ignored; parameters match on type alone
        return ((x.Type == y.Type) && (x.IsByRef == y.IsByRef));
    }

    private Boolean VisitTypeIs(TypeBinaryExpression x, TypeBinaryExpression y)
    {
        return ((x.TypeOperand == y.TypeOperand) &&
               (this.Visit(x.Expression, y.Expression)));
    }

    private Boolean VisitBinding(MemberBinding x, MemberBinding y)
    {
        if ((x.BindingType != y.BindingType) || (x.Member != y.Member))
        {
            return (false);
        }

        switch (x.BindingType)
        {
            case MemberBindingType.Assignment:
                return (this.VisitMemberAssignment((MemberAssignment)x, (MemberAssignment)y));

            case MemberBindingType.MemberBinding:
                return (this.VisitMemberMemberBinding((MemberMemberBinding)x, (MemberMemberBinding)y));

            case MemberBindingType.ListBinding:
                return (this.VisitMemberListBinding((MemberListBinding)x, (MemberListBinding)y));

            default:
                throw (new ArgumentException("Unhandled binding type"));
        }
    }

    private Boolean VisitMemberAssignment(MemberAssignment x, MemberAssignment y)
    {
        return (this.Visit(x.Expression, y.Expression));
    }

    private Boolean VisitMemberMemberBinding(MemberMemberBinding x, MemberMemberBinding y)
    {
        return (this.VisitBindingList(x.Bindings, y.Bindings));
    }

    private Boolean VisitMemberListBinding(MemberListBinding x, MemberListBinding y)
    {
        return (this.VisitElementInitializerList(x.Initializers, y.Initializers));
    }

    private Boolean VisitElementInitializer(ElementInit x, ElementInit y)
    {
        return ((x.AddMethod == y.AddMethod) &&
               (this.VisitExpressionList(x.Arguments, y.Arguments)));
    }

    private Boolean VisitExpressionList(ReadOnlyCollection<Expression> x, ReadOnlyCollection<Expression> y)
    {
        if (x == y)
        {
            return (true);
        }

        if ((x != null) && (y != null) && (x.Count == y.Count))
        {
            for (Int32 i = 0; i < x.Count; i++)
            {
                if (this.Visit(x[i], y[i]) == false)
                {
                    return (false);
                }
            }

            return (true);
        }

        return (false);
    }

    private Boolean VisitParameterList(ReadOnlyCollection<ParameterExpression> x, ReadOnlyCollection<ParameterExpression> y)
    {
        if (x == y)
        {
            return (true);
        }

        if ((x != null) && (y != null) && (x.Count == y.Count))
        {
            for (Int32 i = 0; i < x.Count; i++)
            {
                if (this.Visit(x[i], y[i]) == false)
                {
                    return (false);
                }
            }

            return (true);
        }

        return (false);
    }

    private Boolean VisitBindingList(ReadOnlyCollection<MemberBinding> x, ReadOnlyCollection<MemberBinding> y)
    {
        if (x == y)
        {
            return (true);
        }

        if ((x != null) && (y != null) && (x.Count == y.Count))
        {
            for (Int32 i = 0; i < x.Count; i++)
            {
                if (this.VisitBinding(x[i], y[i]) == false)
                {
                    return (false);
                }
            }

            return (true);
        }

        return (false);
    }

    private Boolean VisitElementInitializerList(ReadOnlyCollection<ElementInit> x, ReadOnlyCollection<ElementInit> y)
    {
        if (x == y)
        {
            return (true);
        }

        if ((x != null) && (y != null) && (x.Count == y.Count))
        {
            for (Int32 i = 0; i < x.Count; i++)
            {
                if (this.VisitElementInitializer(x[i], y[i]) == false)
                {
                    return (false);
                }
            }

            return (true);
        }

        return (false);
    }

    private Boolean VisitMemberList(ReadOnlyCollection<MemberInfo> x, ReadOnlyCollection<MemberInfo> y)
    {
        if (x == y)
        {
            return (true);
        }

        if ((x != null) && (y != null) && (x.Count == y.Count))
        {
            for (Int32 i = 0; i < x.Count; i++)
            {
                if (x[i] != y[i])
                {
                    return (false);
                }
            }

            return (true);
        }

        return (false);
    }
    #endregion

    #region IEqualityComparer<Expression> Members
    public Boolean Equals(Expression x, Expression y)
    {
        return (this.Visit(x, y));
    }

    public Int32 GetHashCode(Expression expression)
    {
        this.hashCode = 0;

        this.Visit(expression);

        return (this.hashCode);
    }
    #endregion
}

This implementation disregards lambda parameter names, so “x => …” is considered equal to “y => …”. Parameters are compared by type only, which also means that two distinct parameters of the same type (in nested lambdas, for example) are treated as interchangeable. Being an IEqualityComparer<Expression>, it implements both Equals and GetHashCode. Note that it is not safe for multithreaded use: GetHashCode stores the hash being computed in an accumulator field (hashCode), so a single instance must not be shared between threads.
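One way to lift the threading restriction is to keep the accumulator in a short-lived object created for each call. The following is a minimal sketch, assuming .NET 4's System.Linq.Expressions.ExpressionVisitor as the traversal; StructuralHasher is a hypothetical name, and it only hashes a few node-specific details (constant values, members, methods), relying on node type and result type for everything else. The multiply-and-add combination is order-sensitive, so it does not suffer from XOR's habit of cancelling identical sub-hashes.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical helper: a structural hash that ignores parameter names.
// All state lives in a per-call instance, so Compute can be called
// from several threads at once.
public sealed class StructuralHasher : ExpressionVisitor
{
    private int hash = 17;

    private StructuralHasher() { }

    public static int Compute(Expression expression)
    {
        StructuralHasher hasher = new StructuralHasher(); // fresh accumulator per call
        hasher.Visit(expression);
        return hasher.hash;
    }

    // Order-sensitive combination of one more value into the running hash.
    private void Add(int value)
    {
        unchecked { this.hash = this.hash * 31 + value; }
    }

    // Every node contributes its node type and result type.
    public override Expression Visit(Expression node)
    {
        if (node != null)
        {
            Add((int)node.NodeType);
            Add(node.Type.GetHashCode());
        }
        return base.Visit(node);
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        Add(node.Value != null ? node.Value.GetHashCode() : 0);
        return base.VisitConstant(node);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        Add(node.Member.GetHashCode());
        return base.VisitMember(node);
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        Add(node.Method.GetHashCode());
        return base.VisitMethodCall(node);
    }

    // VisitParameter is not overridden: Visit already added the node type and
    // parameter type, and the parameter name is deliberately left out.
}
```

With this shape, the comparer's GetHashCode could simply return StructuralHasher.Compute(expression) and keep no mutable fields, at the cost of one small allocation per call.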

Caching

For caching, I decided to use MemoryCache, the ObjectCache implementation that ships with .NET 4 in System.Runtime.Caching, so there are no external dependencies:

using System;
using System.Linq;
using System.Runtime.Caching;

public static class QueryableExtensions
{
    public static IQueryable<T> AsCacheable<T>(this IQueryable<T> queryable, TimeSpan duration)
    {
        return (AsCacheable(queryable, (Int32) duration.TotalSeconds));
    }

    public static IQueryable<T> AsCacheable<T>(this IQueryable<T> queryable, Int32 durationSeconds)
    {
        ObjectCache cache = null;

        //use the ObjectCache supplied through ObjectCache.Host, if there is one
        if (ObjectCache.Host != null)
        {
            cache = ObjectCache.Host.GetService(typeof(ObjectCache)) as ObjectCache;
        }

        //otherwise fall back to the default in-memory cache
        cache = cache ?? MemoryCache.Default;

        IQueryable<T> cachedQuery = new QueryableWrapper<T>(cache, queryable, durationSeconds);

        return (cachedQuery);
    }

    public static IOrderedQueryable<T> AsCacheable<T>(this IOrderedQueryable<T> queryable, TimeSpan duration)
    {
        return (AsCacheable(queryable as IQueryable<T>, duration) as IOrderedQueryable<T>);
    }

    public static IOrderedQueryable<T> AsCacheable<T>(this IOrderedQueryable<T> queryable, Int32 durationSeconds)
    {
        return (AsCacheable(queryable as IQueryable<T>, durationSeconds) as IOrderedQueryable<T>);
    }
}

As you can see, we are free to supply our own ObjectCache implementation: we just place an IServiceProvider implementation (UnityServiceLocator will do) in the ObjectCache.Host property and have it return a valid ObjectCache instance. If none is available, MemoryCache.Default is used. Feel free to replace this with any other similar mechanism!
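ObjectCache.Host accepts any IServiceProvider, so a container is not strictly required. A minimal hand-rolled provider might look like the sketch below; SimpleServiceProvider and its Register method are hypothetical names, not part of .NET.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical minimal IServiceProvider: maps a service type to a single instance.
public sealed class SimpleServiceProvider : IServiceProvider
{
    private readonly Dictionary<Type, object> services = new Dictionary<Type, object>();

    public void Register(Type serviceType, object instance)
    {
        if (serviceType == null) throw new ArgumentNullException("serviceType");
        if (!serviceType.IsInstanceOfType(instance))
        {
            throw new ArgumentException("The instance does not implement the service type.");
        }
        this.services[serviceType] = instance;
    }

    // Returns null for unregistered services; AsCacheable then falls back
    // to MemoryCache.Default because its "as ObjectCache" cast yields null.
    public object GetService(Type serviceType)
    {
        object instance;
        return this.services.TryGetValue(serviceType, out instance) ? instance : null;
    }
}
```

At application startup you would register an ObjectCache, for example provider.Register(typeof(ObjectCache), new MemoryCache("queries")), and then assign the provider to ObjectCache.Host. According to the documentation, that property can only be set once; a second assignment throws InvalidOperationException.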

Preventing Multiple Query Executions

When an IQueryable<T> is first enumerated, it goes to the database, or somewhere else (think of WCF Data Services’ DataServiceQuery<T>), and returns its results. If we put that query in a cache, we must prevent it from executing again; otherwise the cache would be pointless. For that, I built my own class that implements IQueryable<T> (actually IOrderedQueryable<T>, to support ordered queries) and provides its own GetEnumerator, the method IQueryable<T> inherits from IEnumerable<T>. The first enumeration records the items as it reads them from the underlying query; once it has finished, later enumerations of an equivalent query replay the stored items instead of executing again:

   1: sealed class QueryableWrapper<T> : IOrderedQueryable<T>
   2: {
   3:     private static readonly ExpressionEqualityComparer comparer = new ExpressionEqualityComparer();
   4:  
   5:     sealed class EnumeratorWrapper : IEnumerator<T>
   6:     {
   7:         private readonly LinkedList<T> list = new LinkedList<T>();
   8:         private QueryableWrapper<T> queryable;
   9:         private IEnumerator<T> enumerator;
  10:         private Boolean stored = false;
  11:         internal Boolean consumed;
  12:  
  13:         public EnumeratorWrapper(QueryableWrapper<T> queryable, IEnumerator<T> enumerator)
  14:         {
  15:             this.enumerator = enumerator;
  16:             this.queryable = queryable;
  17:         }
  18:  
  19:         internal IEnumerator<T> FromCache()
  20:         {
  21:             return (this.list.GetEnumerator());
  22:         }
  23:  
  24:         #region IEnumerator<T> Members
  25:  
  26:         public T Current
  27:         {
  28:             get
  29:             {
  30:                 T current = this.enumerator.Current;
  31:  
  32:                 if (this.stored == false)
  33:                 {
  34:                     this.list.AddLast(current);
  35:                     this.stored = true;
  36:                 }
  37:  
  38:                 return (current);
  39:             }
  40:         }
  41:  
  42:         #endregion
  43:  
  44:         #region IDisposable Members
  45:  
        public void Dispose()
        {
            //do not mark the results as consumed here: a caller that stops early (e.g. First()) has stored only part of them
            this.stored = false;
            this.enumerator.Dispose();
        }

        #endregion

        #region IEnumerator Members

        Object IEnumerator.Current
        {
            get
            {
                return (this.Current);
            }
        }

        public Boolean MoveNext()
        {
            Boolean result = this.enumerator.MoveNext();

            if (result == true)
            {
                //a new current element is available and has not been stored yet
                this.stored = false;
            }
            else
            {
                //the underlying query was read to the end, so the stored results are complete
                this.consumed = true;
            }

            return (result);
        }

        public void Reset()
        {
            this.stored = false;
            this.consumed = false;
            this.list.Clear();
            this.enumerator.Reset();
        }

        #endregion
    }

    #region Private readonly fields
    private readonly IQueryable<T> queryable;
    private readonly ObjectCache cache;
    private readonly Int32 durationSeconds;
    #endregion

    #region Internal constructor
    internal QueryableWrapper(ObjectCache cache, IQueryable<T> queryable, Int32 durationSeconds)
    {
        this.cache = cache;
        this.queryable = queryable;
        this.durationSeconds = durationSeconds;
    }
    #endregion

    #region IEnumerable<T> Members

    public IEnumerator<T> GetEnumerator()
    {
        String key = this.GetKey(this.queryable).ToString();
        //Get returns null for a missing or expired entry, so the entry cannot expire between checking and reading it
        EnumeratorWrapper cached = this.cache.Get(key) as EnumeratorWrapper;

        if (cached != null)
        {
            //hit
            if (cached.consumed == true)
            {
                //the first enumeration ran to the end: replay the stored results
                return (cached.FromCache());
            }

            //the first enumeration is still running or was abandoned, so its live enumerator must not be shared
            return (this.queryable.GetEnumerator());
        }

        //miss
        EnumeratorWrapper enumerator = new EnumeratorWrapper(this, this.queryable.GetEnumerator());
        this.cache.Add(key, enumerator, DateTimeOffset.Now.AddSeconds(this.durationSeconds));

        return (enumerator);
    }

    #endregion

    #region IEnumerable Members

    IEnumerator IEnumerable.GetEnumerator()
    {
        return (this.GetEnumerator());
    }

    #endregion

    #region IQueryable Members

    public Type ElementType
    {
        get
        {
            return (this.queryable.ElementType);
        }
    }

    public Expression Expression
    {
        get
        {
            return (this.queryable.Expression);
        }
    }

    public IQueryProvider Provider
    {
        get
        {
            return (this.queryable.Provider);
        }
    }

    #endregion

    #region Private methods
    private Int32 GetKey(IQueryable queryable)
    {
        //only the hash code is used as the key, so two different expressions whose hashes collide would share an entry
        return (comparer.GetHashCode(queryable.Expression));
    }
    #endregion
}
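Leaving out the cache, the expression key and the IQueryable plumbing, the record-then-replay idea behind EnumeratorWrapper can be shown in isolation. The following is a simplified sketch, not the class used above: ReplayableSequence<T>, ExpensiveQuery and the Executions counter are hypothetical names, and, like the wrapper in this post, it is not thread-safe. It only keeps the recorded results once the source has been read to the end, so a caller that stops early cannot leave a truncated list behind to be replayed.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

// Hypothetical, simplified stand-in for EnumeratorWrapper: records elements while the source
// is enumerated and replays them later, without cache keys, expiration or IQueryable.
public sealed class ReplayableSequence<T> : IEnumerable<T>
{
    private readonly IEnumerable<T> source;
    private List<T> results; // set only after one enumeration has reached the end

    public ReplayableSequence(IEnumerable<T> source)
    {
        this.source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public IEnumerator<T> GetEnumerator()
    {
        if (this.results != null)
        {
            // A complete result set is available: replay it without running the source.
            return this.results.GetEnumerator();
        }

        return this.Record();
    }

    private IEnumerator<T> Record()
    {
        var buffer = new List<T>();
        foreach (T item in this.source)
        {
            buffer.Add(item);
            yield return item;
        }

        // Only reached when the caller read the whole sequence. A caller that stops early
        // (for example First()) disposes this iterator before this line runs.
        this.results = buffer;
    }

    IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();
}

public static class Program
{
    // Counts how many times the "expensive" source actually runs.
    public static int Executions;

    // Stand-in for a database query: the counter is bumped each time it is enumerated.
    public static IEnumerable<int> ExpensiveQuery()
    {
        Executions++;
        yield return 1;
        yield return 2;
        yield return 3;
    }

    public static void Main()
    {
        var cached = new ReplayableSequence<int>(ExpensiveQuery());
        Console.WriteLine(cached.First());  // 1 (stops early, nothing recorded)
        Console.WriteLine(cached.Sum());    // 6 (runs to the end, results recorded)
        Console.WriteLine(cached.Count());  // 3 (replayed from memory)
        Console.WriteLine(Executions);      // 2
    }
}
```

In Main, First() runs the source but records nothing, Sum() runs it again to the end and records the results, and Count() is answered from memory, so the source runs twice in total.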

Putting It All Together

This allows me to write code like this:

//cache miss
var q1 = ctx.Customers.Where(x => x.Orders.Any()).OrderBy(x => x.Name).AsCacheable(TimeSpan.FromSeconds(10)).ToList();

//cache hit
var q2 = ctx.Customers.Where(o => o.Orders.Any()).OrderBy(x => x.Name).AsCacheable(TimeSpan.FromSeconds(10)).ToList();

//cache hit
var q3 = (from c in ctx.Customers where c.Orders.Any() orderby c.Name select c).AsCacheable(TimeSpan.FromSeconds(10)).ToList();

Thread.Sleep(10000);

//cache miss
var q4 = ctx.Customers.Where(x => x.Orders.Any()).OrderBy(x => x.Name).AsCacheable(TimeSpan.FromSeconds(10)).ToList();

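By calling the AsCacheable extension method, our LINQ queries are cached for the specified duration. Because the wrapper relies only on the query's expression tree and on enumerating its results, this should work with any LINQ provider that exposes IQueryable<T>.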
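The cache hits for q2 and q3 depend on the expression comparer treating them as the same query as q1. One self-contained way to see what it actually has to compare is to inspect the expression trees the C# compiler builds. In this sketch an empty in-memory list stands in for ctx.Customers, and Customer, Order, QueryShape and SampleQueries are hypothetical names used only for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entities standing in for the ones behind ctx.Customers.
public class Order { }

public class Customer
{
    public string Name { get; set; } = "";
    public List<Order> Orders { get; set; } = new List<Order>();
}

public static class QueryShape
{
    // Names of the chained query operators, outermost first: OrderBy(Where(source, ...), ...) gives OrderBy, Where.
    public static List<string> OperatorNames(Expression expression)
    {
        var names = new List<string>();
        while (expression is MethodCallExpression call)
        {
            names.Add(call.Method.Name);
            expression = call.Arguments[0]; // the first argument is the source the operator applies to
        }
        return names;
    }

    // Parameter name of the lambda passed to each operator, outermost first.
    public static List<string> LambdaParameterNames(Expression expression)
    {
        var names = new List<string>();
        while (expression is MethodCallExpression call)
        {
            // Queryable operators receive their lambda wrapped in a Quote node.
            if (call.Arguments.Count > 1 &&
                call.Arguments[1] is UnaryExpression { NodeType: ExpressionType.Quote } quote &&
                quote.Operand is LambdaExpression lambda)
            {
                names.Add(lambda.Parameters[0].Name);
            }
            expression = call.Arguments[0];
        }
        return names;
    }
}

public static class Program
{
    // The three queries from the usage example, built over an empty in-memory source instead of ctx.
    public static IQueryable<Customer>[] SampleQueries()
    {
        IQueryable<Customer> customers = new List<Customer>().AsQueryable();
        var q1 = customers.Where(x => x.Orders.Any()).OrderBy(x => x.Name);
        var q2 = customers.Where(o => o.Orders.Any()).OrderBy(x => x.Name);
        var q3 = from c in customers where c.Orders.Any() orderby c.Name select c;
        return new IQueryable<Customer>[] { q1, q2, q3 };
    }

    public static void Main()
    {
        foreach (IQueryable<Customer> q in SampleQueries())
        {
            Console.WriteLine(string.Join(",", QueryShape.OperatorNames(q.Expression)) + " | " +
                              string.Join(",", QueryShape.LambdaParameterNames(q.Expression)));
        }
    }
}
```

Run as-is, it prints OrderBy,Where for all three queries, with lambda parameter names x,x for q1, x,o for q2 and c,c for q3. The query-syntax version has no Select call because the C# compiler drops a trailing `select c` that just returns the range variable, so the only difference between the three trees is the parameter names, which the comparer has to disregard for q2 and q3 to be hits.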


9 Comments

  • Looks interesting.
    Is this available as a GitHub project or NuGet package somewhere?

  • Daniel:
    It will be soon, I hope! I'll try to publish most of my projects there. In the meantime, try copy & paste; if you need assistance, just let me know!

  • pls put it in Github, thx. :)

  • How do you handle objects cached in MemoryCache being shared by everything that retrieves them? That caused me a right headache when I implemented my own caching of queryable results: values would be changed in the cached data before being saved back to the db. I ended up cloning the objects into and out of the cache.

  • Mark:
    That can certainly happen. Objects are shared with other classes (for example, Entity Framework contexts and NHibernate sessions), and they can change. This does not address that problem; it merely associates some objects with a LINQ expression, like in a dictionary.

  • Nicely done!

    What is actually used as the cache key here? The LINQ expression or the generated SQL?

    Have you tried this approach with compiled LINQ queries?

  • colombia:
    The cache key is the expression hash, as returned from ExpressionEqualityComparer.

  • Thanks for the code. However, it looks like the code is not thread-safe and a bit buggy. If the condition
    if ((enumerator as EnumeratorWrapper).consumed == true) is not true because someone is still using it, you return the enumerator variable, which is null/not set. I guess the only safe way is to return this.queryable.GetEnumerator(), as when the enumerator is not yet cached. If one returned the 'live' enumerator from the cache, it would be used by two threads, so both would get a random subset of the result while fighting over it. Also, the enumerator code would need to be thread-safe too (e.g. adding to the linked list from two threads in the Current method).

  • Hi, fanoush!
    Yes, it's not thread safe. Feel free to improve it, and share it with the World! ;-)
