Filter Collections Automatically With Entity Framework Code First

Introduction

In some O/RMs, it is possible to specify automatic filters for entity collections, such as one-to-many or many-to-many associations, which are applied whenever these collections are populated. Entity Framework does not offer such a mechanism; however, it is possible to implement one.

Context Collections

In Entity Framework Code First, entities are exposed as IDbSet<T> or DbSet<T> collections on a context, a DbContext-derived class. Entity Framework offers no way to set a filter that automatically applies to every query coming from these collections, so we need to create our own IDbSet<T> class. Let’s call it FilteredDbSet<T> and have it implement the same interfaces as DbSet<T>, so that it can be used in its place transparently:

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;

public class FilteredDbSet<TEntity> : IDbSet<TEntity>, IOrderedQueryable<TEntity> where TEntity : class
{
    #region Private readonly fields
    private readonly DbSet<TEntity> set;
    private readonly Func<TEntity, Boolean> matchesFilter;
    #endregion

    #region Public constructors
    public FilteredDbSet(DbContext context, Expression<Func<TEntity, Boolean>> filter)
    {
        // The real, unfiltered set that every call is delegated to
        this.set = context.Set<TEntity>();
        // The expression is composed into queries; the compiled delegate validates single entities in memory
        this.Filter = filter;
        this.matchesFilter = filter.Compile();
    }
    #endregion

    #region Public properties
    public Expression<Func<TEntity, Boolean>> Filter
    {
        get;
        protected set;
    }

    // Escape hatch for queries that must ignore the filter
    public IQueryable<TEntity> Unfiltered
    {
        get
        {
            return (this.set);
        }
    }
    #endregion

    #region Public methods
    public IQueryable<TEntity> Include(String path)
    {
        return (this.set.Include(path).Where(this.Filter));
    }

    // Raw SQL is passed through as-is and is NOT filtered
    public DbSqlQuery<TEntity> SqlQuery(String sql, params Object[] parameters)
    {
        return (this.set.SqlQuery(sql, parameters));
    }
    #endregion

    #region IDbSet<TEntity> Members
    TEntity IDbSet<TEntity>.Add(TEntity entity)
    {
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (this.set.Add(entity));
    }

    TEntity IDbSet<TEntity>.Attach(TEntity entity)
    {
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (this.set.Attach(entity));
    }

    TDerivedEntity IDbSet<TEntity>.Create<TDerivedEntity>()
    {
        return (this.set.Create<TDerivedEntity>());
    }

    TEntity IDbSet<TEntity>.Create()
    {
        return (this.set.Create());
    }

    TEntity IDbSet<TEntity>.Find(params Object[] keyValues)
    {
        // Find looks entities up by key, bypassing the filter, so check the result explicitly
        var entity = this.set.Find(keyValues);
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (entity);
    }

    TEntity IDbSet<TEntity>.Remove(TEntity entity)
    {
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (this.set.Remove(entity));
    }

    ObservableCollection<TEntity> IDbSet<TEntity>.Local
    {
        get { return (this.set.Local); }
    }
    #endregion

    #region IEnumerable<TEntity> Members
    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return (this.set.Where(this.Filter).GetEnumerator());
    }
    #endregion

    #region IEnumerable Members
    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((this as IEnumerable<TEntity>).GetEnumerator());
    }
    #endregion

    #region IQueryable Members
    Type IQueryable.ElementType
    {
        get { return ((this.set as IQueryable).ElementType); }
    }

    // Every LINQ operator composed on this set starts from set.Where(Filter)
    Expression IQueryable.Expression
    {
        get
        {
            return (this.set.Where(this.Filter).Expression);
        }
    }

    IQueryProvider IQueryable.Provider
    {
        get
        {
            return ((this.set as IQueryable).Provider);
        }
    }
    #endregion

    #region Private methods
    private void ThrowIfEntityDoesNotMatchFilter(TEntity entity)
    {
        if ((entity != null) && (this.matchesFilter(entity) == false))
        {
            throw (new ArgumentException("Entity does not match filter", "entity"));
        }
    }
    #endregion
}
```

In the constructor of our DbContext, we create an instance of this class for each filtered set, passing it the context and a LINQ filter expression:

```csharp
public class MyContext : DbContext
{
    public MyContext()
    {
        // Every query issued through Bases is restricted to SomeProperty == 1
        this.Bases = new FilteredDbSet<Base>(this, x => x.SomeProperty == 1);
    }

    public IDbSet<Base> Bases { get; protected set; }
}
```

From then on, all LINQ queries over the Bases collection are restricted to entities whose SomeProperty equals 1.
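The IQueryable members are what make the restriction stick: FilteredDbSet<T> exposes the expression of set.Where(Filter) together with the set's own query provider, so any operator composed on top of it starts from the filtered tree. The sketch below reproduces that mechanism without Entity Framework, assuming an in-memory List<T>.AsQueryable() as a stand-in for DbSet<T>; FilteredQueryable<T> is a hypothetical name used only for this illustration.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical, framework-free illustration of the IQueryable part of FilteredDbSet<T>:
// the wrapper exposes the *filtered* expression but reuses the source's query provider.
public class FilteredQueryable<T> : IOrderedQueryable<T>
{
    private readonly IQueryable<T> source;
    private readonly Expression<Func<T, bool>> filter;

    public FilteredQueryable(IQueryable<T> source, Expression<Func<T, bool>> filter)
    {
        this.source = source;
        this.filter = filter;
    }

    // Operators composed on this object (Where, Count, Sum...) are built on top of source.Where(filter)
    public Expression Expression => this.source.Where(this.filter).Expression;

    public Type ElementType => typeof(T);

    // The provider of the unfiltered source, as FilteredDbSet<T> does with DbSet<T>
    public IQueryProvider Provider => this.source.Provider;

    // Plain enumeration also goes through the filtered expression
    public IEnumerator<T> GetEnumerator() =>
        this.Provider.CreateQuery<T>(this.Expression).GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();
}
```

Over an in-memory list, new FilteredQueryable<int>(numbers.AsQueryable(), x => x % 2 == 0).Count() counts only the even numbers; with Entity Framework's provider, the same composed expression tree would be translated to SQL instead.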

Entity Collections

Collections on entities are a different matter. For these, we usually declare a property of type ICollection<T> and let Entity Framework create the instance for us when it loads the entity. The class responsible for creating this instance is DbCollectionEntry, which unfortunately cannot be subclassed, because it has no public or protected constructors and no virtual methods. Let’s take a different path and create our own collection class instead:

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

[Serializable]
public class FilteredCollection<T> : ICollection<T>
{
    private readonly DbCollectionEntry collectionEntry;
    private readonly Func<T, Boolean> compiledFilter;
    private ICollection<T> collection;

    public FilteredCollection(ICollection<T> collection, DbCollectionEntry collectionEntry, Expression<Func<T, Boolean>> filter)
    {
        this.Filter = filter;
        this.collectionEntry = collectionEntry;
        this.compiledFilter = filter.Compile();
        // Use our own storage, so the incoming collection is never modified while it is being enumerated
        this.collection = new HashSet<T>();

        if (collection != null)
        {
            // Take over the existing items and become the collection exposed by the navigation property
            foreach (T entity in collection)
            {
                this.collection.Add(entity);
            }

            this.collectionEntry.CurrentValue = this;
        }
        else
        {
            this.LoadIfNecessary();
        }
    }

    public Expression<Func<T, Boolean>> Filter
    {
        get;
        private set;
    }

    protected void ThrowIfInvalid(T entity)
    {
        if (this.compiledFilter(entity) == false)
        {
            throw (new ArgumentException("Entity does not match filter", "entity"));
        }
    }

    protected void LoadIfNecessary()
    {
        if (this.collectionEntry.IsLoaded == false)
        {
            // Load only the related entities that match the filter; the restriction becomes part of the query
            IQueryable<T> query = this.collectionEntry.Query().Cast<T>().Where(this.Filter);

            this.collection = query.ToList();

            this.collectionEntry.CurrentValue = this;

            // Mark the navigation property as loaded so Entity Framework does not load it again.
            // This relies on private Entity Framework field names, which may differ between versions.
            var flags = BindingFlags.NonPublic | BindingFlags.Instance;
            var _internalCollectionEntry = this.collectionEntry.GetType().GetField("_internalCollectionEntry", flags).GetValue(this.collectionEntry);
            var _relatedEnd = _internalCollectionEntry.GetType().BaseType.GetField("_relatedEnd", flags).GetValue(_internalCollectionEntry);
            _relatedEnd.GetType().GetField("_isLoaded", flags).SetValue(_relatedEnd, true);
        }
    }

    #region ICollection<T> Members

    void ICollection<T>.Add(T item)
    {
        this.LoadIfNecessary();
        this.ThrowIfInvalid(item);
        this.collection.Add(item);
    }

    void ICollection<T>.Clear()
    {
        this.LoadIfNecessary();
        this.collection.Clear();
    }

    Boolean ICollection<T>.Contains(T item)
    {
        this.LoadIfNecessary();
        return (this.collection.Contains(item));
    }

    void ICollection<T>.CopyTo(T[] array, Int32 arrayIndex)
    {
        this.LoadIfNecessary();
        this.collection.CopyTo(array, arrayIndex);
    }

    Int32 ICollection<T>.Count
    {
        get
        {
            this.LoadIfNecessary();
            return (this.collection.Count);
        }
    }

    Boolean ICollection<T>.IsReadOnly
    {
        get
        {
            this.LoadIfNecessary();
            return (this.collection.IsReadOnly);
        }
    }

    Boolean ICollection<T>.Remove(T item)
    {
        this.LoadIfNecessary();
        return (this.collection.Remove(item));
    }

    #endregion

    #region IEnumerable<T> Members

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        this.LoadIfNecessary();
        return (this.collection.GetEnumerator());
    }

    #endregion

    #region IEnumerable Members

    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((this as IEnumerable<T>).GetEnumerator());
    }

    #endregion
}
```

This collection receives a reference to a possibly existing collection and the DbCollectionEntry responsible for loading it. Instead of letting DbCollectionEntry load the collection, we load it ourselves by applying our custom restriction to the query returned by its Query method, and then use a bit of reflection magic to make Entity Framework believe that the collection has already been loaded (IsLoaded), so that it does not load it again.
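The load-once-and-validate behavior can be seen in isolation in the following sketch. It assumes a plain Func<IQueryable<T>> loader in place of DbCollectionEntry.Query() and a bool field in place of IsLoaded, so no reflection is needed; LazyFilteredCollection<T>, the loader delegate and the LoadCount counter are hypothetical names, not part of Entity Framework.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical, framework-free sketch of the FilteredCollection<T> idea:
// items are loaded on first access through a loader (standing in for DbCollectionEntry.Query()),
// the filter expression is applied to that query, and Add rejects items that do not match it.
public class LazyFilteredCollection<T> : ICollection<T>
{
    private readonly Func<IQueryable<T>> loader;
    private readonly Expression<Func<T, bool>> filter;
    private readonly Func<T, bool> compiledFilter;
    private List<T> items;
    private bool isLoaded; // plays the role of DbCollectionEntry.IsLoaded

    public LazyFilteredCollection(Func<IQueryable<T>> loader, Expression<Func<T, bool>> filter)
    {
        this.loader = loader;
        this.filter = filter;
        this.compiledFilter = filter.Compile();
    }

    // Number of times the loader has run (for illustration only)
    public int LoadCount { get; private set; }

    private void LoadIfNecessary()
    {
        if (this.isLoaded) return;
        this.items = this.loader().Where(this.filter).ToList();
        this.isLoaded = true;
        this.LoadCount++;
    }

    public void Add(T item)
    {
        this.LoadIfNecessary();
        if (!this.compiledFilter(item))
            throw new ArgumentException("Entity does not match filter", nameof(item));
        this.items.Add(item);
    }

    public void Clear() { this.LoadIfNecessary(); this.items.Clear(); }
    public bool Contains(T item) { this.LoadIfNecessary(); return this.items.Contains(item); }
    public void CopyTo(T[] array, int arrayIndex) { this.LoadIfNecessary(); this.items.CopyTo(array, arrayIndex); }
    public bool Remove(T item) { this.LoadIfNecessary(); return this.items.Remove(item); }
    public int Count { get { this.LoadIfNecessary(); return this.items.Count; } }
    public bool IsReadOnly => false;

    public IEnumerator<T> GetEnumerator() { this.LoadIfNecessary(); return this.items.GetEnumerator(); }
    IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();
}
```

Nothing is loaded until the collection is first touched, the loader runs only once, and adding an item outside the filter throws, mirroring ThrowIfInvalid.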

Now, to use this collection, we must handle the ObjectMaterialized event of the underlying ObjectContext, which is raised whenever an entity object is created from data returned by a query or load operation. We set up the filter through an extension method over DbContext:

```csharp
using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Core.Objects;   // EF6; in EF5, ObjectMaterializedEventArgs lives in System.Data.Objects
using System.Data.Entity.Infrastructure;
using System.Linq.Expressions;

public static class DbContextExtensions
{
    public static void Filter<TContext, TParentEntity, TCollectionEntity>(this TContext context, Expression<Func<TContext, IDbSet<TParentEntity>>> path, Expression<Func<TParentEntity, ICollection<TCollectionEntity>>> collection, Expression<Func<TCollectionEntity, Boolean>> filter)
        where TContext : DbContext
        where TParentEntity : class, new()
        where TCollectionEntity : class
    {
        // 'path' only serves to infer TParentEntity; the navigation property name is read from 'collection'
        String navigationProperty = ((MemberExpression)collection.Body).Member.Name;

        (context as IObjectContextAdapter).ObjectContext.ObjectMaterialized += delegate(Object sender, ObjectMaterializedEventArgs e)
        {
            if (e.Entity is TParentEntity)
            {
                // Replace the navigation collection of every materialized parent with a filtered one
                DbCollectionEntry col = context.Entry(e.Entity).Collection(navigationProperty);
                col.CurrentValue = new FilteredCollection<TCollectionEntity>(null, col, filter);
            }
        };
    }
}
```

Entity Framework exposes through the navigation property whatever ICollection<T> is assigned to the DbCollectionEntry.CurrentValue property, so we assign it an instance of our FilteredCollection<T>.
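The extension method needs the navigation property name as a string for context.Entry(...).Collection(...). Rather than parsing the text of the lambda (for example with ToString().Split('.')), the name can be read from the expression tree itself. The helper below is a hypothetical sketch, not an Entity Framework API, and the Product and Detail classes are illustrative stand-ins for the article's entities.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Illustrative entities (hypothetical)
public class Detail
{
    public int SomeProperty { get; set; }
}

public class Product
{
    public ICollection<Detail> Details { get; set; } = new List<Detail>();
}

public static class ExpressionHelper
{
    // Hypothetical helper: returns "Details" for p => p.Details
    public static string NavigationPropertyName<TParent, TCollection>(
        Expression<Func<TParent, ICollection<TCollection>>> collection)
    {
        Expression body = collection.Body;

        // If the property's declared type differs from ICollection<T>, the compiler may wrap
        // the member access in a Convert node; unwrap it before inspecting the member
        if (body is UnaryExpression unary && unary.NodeType == ExpressionType.Convert)
        {
            body = unary.Operand;
        }

        // Only accept a direct property access on the lambda parameter
        if (body is MemberExpression member && member.Expression is ParameterExpression)
        {
            return member.Member.Name;
        }

        throw new ArgumentException("Expression must be a direct property access such as p => p.Details", nameof(collection));
    }
}
```

Unlike string splitting, anything other than a simple property access is rejected with an explicit error instead of silently producing a wrong name.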

We use this extension method like this:

```csharp
ctx.Filter(p => p.Products, p => p.Details, p => p.SomeProperty == 0);
```

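This filters the Details collection of every Product entity the context materializes, keeping only the details whose SomeProperty equals 0. The first expression, p => p.Products, only identifies the parent entity type: the filter applies to Product entities regardless of the query that loaded them.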
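The event-driven mechanism behind this can be simulated without Entity Framework. In the sketch below, a hypothetical Materializer raises an event for each object it creates, much like ObjectMaterialized, and a subscriber swaps the Details collection for a filtered one. Materializer, Product and Detail are illustrative stand-ins, not Entity Framework types, and unlike FilteredCollection<T> this version filters an already populated in-memory collection instead of issuing a filtered query.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Illustrative entities (hypothetical)
public class Detail
{
    public int SomeProperty { get; set; }
}

public class Product
{
    public ICollection<Detail> Details { get; set; } = new List<Detail>();
}

// Hypothetical stand-in for ObjectContext: raises an event right after each object is created
public class Materializer
{
    public event Action<object> ObjectMaterialized;

    public T Materialize<T>(Func<T> factory)
    {
        T entity = factory();
        this.ObjectMaterialized?.Invoke(entity);
        return entity;
    }
}

public static class MaterializerExtensions
{
    // Same shape as the Filter extension method: every materialized TParent
    // gets its collection replaced by one holding only the matching items
    public static void Filter<TParent, TItem>(
        this Materializer materializer,
        Func<TParent, ICollection<TItem>> getCollection,
        Action<TParent, ICollection<TItem>> setCollection,
        Func<TItem, bool> filter)
    {
        materializer.ObjectMaterialized += entity =>
        {
            if (entity is TParent parent)
            {
                setCollection(parent, getCollection(parent).Where(filter).ToList());
            }
        };
    }
}
```

Because the subscription is attached to the materializer rather than to a particular query, every Product it creates is filtered, which is the same property the Entity Framework version has.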

Conclusion

As you can see, even though Entity Framework does not offer all the functionality we might be used to, it still has enough extensibility points to let us build it ourselves. The same technique presented here can be used to build lazy-loaded or even IQueryable<T> collections, two interesting ideas that I leave as an exercise for you!

2 Comments

  • Hi, I've implemented this solution into our data access code base.
    But I'm getting an exception when it's trying to load the collection from EF.

    Here's the Stacktrace:
    [ArgumentException: entity]
    ITDepartment.Common.Utils.DataAccess.Filtering.FilteredCollection`1.ThrowIfInvalid(T entity) in c:\Projects\Common\dev\ITDepartment.Common.Utils\ITDepartment.Common.Utils.DataAccess\Filtering\FilteredCollection.cs:115
    ITDepartment.Common.Utils.DataAccess.Filtering.FilteredCollection`1.System.Collections.Generic.ICollection<T>.Add(T item) in c:\Projects\Common\dev\ITDepartment.Common.Utils\ITDepartment.Common.Utils.DataAccess\Filtering\FilteredCollection.cs:54
    System.Data.Entity.Core.Objects.Internal.PocoPropertyAccessorStrategy.CollectionAdd(RelatedEnd relatedEnd, Object value) +203

    For some reason, EF tries to re-add the filtered items (DeletedDate != null) to the collection. Do you have any idea on how to fix this?

    Best regards,
    Rick

  • Hi, Rick!
    I don't know, but I can try to have a look... if you want, send me your code (bare minimum to see the problem) to rjperes at hotmail.
