Filter Collections Automatically With Entity Framework Code First
Introduction
In some O/RMs, it is possible to specify automatic filters for entity collections, such as one-to-many or many-to-many associations. These filters are applied whenever the collections are populated. Entity Framework does not offer such a mechanism; however, it is possible to implement one.
Context Collections
In Entity Framework Code First, entities are exposed as IDbSet<T> or DbSet<T> collections on a context, a DbContext-derived class. There is no built-in way to set a filter that automatically applies to all queries coming from these collections, unless we create our own IDbSet<T> class. Let’s call it FilteredDbSet<T> and have it implement the same interfaces as DbSet<T>, so that it can transparently be used in its place:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;

public class FilteredDbSet<TEntity> : IDbSet<TEntity>, IOrderedQueryable<TEntity> where TEntity : class
{
    #region Private readonly fields
    private readonly DbSet<TEntity> set;
    private readonly Func<TEntity, Boolean> matchesFilter;
    #endregion

    #region Public constructors
    public FilteredDbSet(DbContext context, Expression<Func<TEntity, Boolean>> filter)
    {
        // Wrap the context's own set for this entity type
        this.set = context.Set<TEntity>();
        // Keep the expression for queries and a compiled copy for in-memory checks
        this.Filter = filter;
        this.matchesFilter = filter.Compile();
    }
    #endregion

    #region Public properties
    public Expression<Func<TEntity, Boolean>> Filter
    {
        get;
        protected set;
    }

    // Bypasses the filter when the full set is really needed
    public IQueryable<TEntity> Unfiltered
    {
        get
        {
            return (this.set);
        }
    }
    #endregion

    #region Public methods
    public IQueryable<TEntity> Include(String path)
    {
        return (this.set.Include(path).Where(this.Filter));
    }

    // Raw SQL is passed through unfiltered
    public DbSqlQuery<TEntity> SqlQuery(String sql, params Object[] parameters)
    {
        return (this.set.SqlQuery(sql, parameters));
    }
    #endregion

    #region IDbSet<TEntity> Members
    TEntity IDbSet<TEntity>.Add(TEntity entity)
    {
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (this.set.Add(entity));
    }

    TEntity IDbSet<TEntity>.Attach(TEntity entity)
    {
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (this.set.Attach(entity));
    }

    TDerivedEntity IDbSet<TEntity>.Create<TDerivedEntity>()
    {
        return (this.set.Create<TDerivedEntity>());
    }

    TEntity IDbSet<TEntity>.Create()
    {
        return (this.set.Create());
    }

    TEntity IDbSet<TEntity>.Find(params Object[] keyValues)
    {
        var entity = this.set.Find(keyValues);
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (entity);
    }

    TEntity IDbSet<TEntity>.Remove(TEntity entity)
    {
        this.ThrowIfEntityDoesNotMatchFilter(entity);
        return (this.set.Remove(entity));
    }

    // Note: Local returns every tracked entity of this type, unfiltered
    ObservableCollection<TEntity> IDbSet<TEntity>.Local
    {
        get { return (this.set.Local); }
    }
    #endregion

    #region IEnumerable<TEntity> Members
    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return (this.set.Where(this.Filter).GetEnumerator());
    }
    #endregion

    #region IEnumerable Members
    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((this as IEnumerable<TEntity>).GetEnumerator());
    }
    #endregion

    #region IQueryable Members
    Type IQueryable.ElementType
    {
        get { return ((this.set as IQueryable).ElementType); }
    }

    // LINQ operators applied to this set start from set.Where(Filter)
    Expression IQueryable.Expression
    {
        get
        {
            return (this.set.Where(this.Filter).Expression);
        }
    }

    IQueryProvider IQueryable.Provider
    {
        get
        {
            return ((this.set as IQueryable).Provider);
        }
    }
    #endregion

    #region Private methods
    private void ThrowIfEntityDoesNotMatchFilter(TEntity entity)
    {
        if ((entity != null) && (this.matchesFilter(entity) == false))
        {
            throw (new ArgumentException("Entity does not match filter", "entity"));
        }
    }
    #endregion
}
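The key trick in FilteredDbSet<TEntity> is that IQueryable.Expression returns the wrapped set's expression with Where(filter) already appended, while the query provider is reused unchanged, so any operator a caller adds is composed on top of the restriction. The same trick can be tried without a database. The sketch below assumes a LINQ to Objects source (from AsQueryable()) standing in for DbSet<TEntity>; FilteredQueryable<T> is a hypothetical name, not part of Entity Framework.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical in-memory analogue of FilteredDbSet<TEntity>: the filter is
// baked into the query expression, and the source's provider is reused.
public class FilteredQueryable<T> : IOrderedQueryable<T>
{
    private readonly IQueryable<T> source;
    private readonly IQueryable<T> filtered;

    public FilteredQueryable(IQueryable<T> source, Expression<Func<T, bool>> filter)
    {
        this.source = source;
        this.filtered = source.Where(filter);
    }

    // Escape hatch, like FilteredDbSet<TEntity>.Unfiltered
    public IQueryable<T> Unfiltered => this.source;

    public Type ElementType => typeof(T);

    // Operators composed on this object start from source.Where(filter)
    public Expression Expression => this.filtered.Expression;

    // The original provider executes the combined expression
    public IQueryProvider Provider => this.source.Provider;

    public IEnumerator<T> GetEnumerator() => this.filtered.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();
}
```

For example, wrapping the numbers 1 to 6 with n => n % 2 == 0 makes Count() see only the even numbers, and a further Where(n => n > 2) yields 4 and 6, because both operators are built on the restricted expression.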
In the constructor of our DbContext, we create instances of this class, passing a LINQ restriction expression to its constructor:
public class MyContext : DbContext
{
    public MyContext()
    {
        // Every query over Bases is restricted to SomeProperty == 1
        this.Bases = new FilteredDbSet<Base>(this, x => x.SomeProperty == 1);
    }

    public IDbSet<Base> Bases { get; protected set; }
}
From now on, all queries over the Bases collection will be restricted.
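FilteredDbSet<TEntity> keeps the filter in two forms: as an expression tree, which is composed into queries, and as a compiled delegate, which rejects entities passed to Add, Attach, Find and Remove. A small in-memory sketch of that dual use follows; Item and FilteredStore<T> are hypothetical stand-ins, and a List<T> plays the role of the database table, so rows that do not match the filter can exist but are never returned.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity, mirroring Base.SomeProperty from the example above
public class Item
{
    public int SomeProperty { get; set; }
}

// Hypothetical in-memory analogue of FilteredDbSet<T>: one filter, two forms
public class FilteredStore<T>
{
    private readonly List<T> table;            // stands in for the database table
    private readonly Func<T, bool> matchesFilter;

    public FilteredStore(List<T> table, Expression<Func<T, bool>> filter)
    {
        this.table = table;
        this.Filter = filter;                  // expression tree: composed into queries
        this.matchesFilter = filter.Compile(); // delegate: checks single entities in memory
    }

    public Expression<Func<T, bool>> Filter { get; }

    // Every query starts from the restricted sequence
    public IQueryable<T> Query() => this.table.AsQueryable().Where(this.Filter);

    public void Add(T entity)
    {
        // Refuse entities that the filter would hide from later queries
        if (!this.matchesFilter(entity))
        {
            throw new ArgumentException("Entity does not match filter", nameof(entity));
        }

        this.table.Add(entity);
    }
}
```

Rejecting non-matching entities on Add keeps the set consistent: without the check, an entity could be saved and then silently disappear from every subsequent query.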
Entity Collections
A different matter is collections on entities. For these, we usually declare a property of type ICollection<T> and let Entity Framework create an instance for us when it loads the entity. The class responsible for creating this instance is DbCollectionEntry, which unfortunately cannot be subclassed, because it has no public or protected constructors or virtual methods. Let’s take a different path and create our own collection class instead:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

[Serializable]
public class FilteredCollection<T> : ICollection<T>
{
    private readonly DbCollectionEntry collectionEntry;
    private readonly Func<T, Boolean> compiledFilter;
    private ICollection<T> collection;

    public FilteredCollection(ICollection<T> collection, DbCollectionEntry collectionEntry, Expression<Func<T, Boolean>> filter)
    {
        this.Filter = filter;
        this.collection = new HashSet<T>();
        this.collectionEntry = collectionEntry;
        this.compiledFilter = filter.Compile();

        if (collection != null)
        {
            // Copy the existing entities into our own storage
            foreach (T entity in collection)
            {
                this.collection.Add(entity);
            }

            this.collectionEntry.CurrentValue = this;
        }
        else
        {
            // No existing collection: fetch the filtered contents now
            this.LoadIfNecessary();
        }
    }

    public Expression<Func<T, Boolean>> Filter
    {
        get;
        private set;
    }

    protected void ThrowIfInvalid(T entity)
    {
        if (this.compiledFilter(entity) == false)
        {
            throw (new ArgumentException("Entity does not match filter", "entity"));
        }
    }

    protected void LoadIfNecessary()
    {
        if (this.collectionEntry.IsLoaded == false)
        {
            // Restrict the navigation property's query with our filter
            IQueryable<T> query = this.collectionEntry.Query().Cast<T>().Where(this.Filter);

            this.collection = query.ToList();

            this.collectionEntry.CurrentValue = this;

            // Mark the navigation property as loaded so that Entity Framework does not load it again.
            // These private fields are Entity Framework internals and may change between versions.
            var _internalCollectionEntry = this.collectionEntry.GetType().GetField("_internalCollectionEntry", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(this.collectionEntry);
            var _relatedEnd = _internalCollectionEntry.GetType().BaseType.GetField("_relatedEnd", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(_internalCollectionEntry);
            _relatedEnd.GetType().GetField("_isLoaded", BindingFlags.NonPublic | BindingFlags.Instance).SetValue(_relatedEnd, true);
        }
    }

    #region ICollection<T> Members

    void ICollection<T>.Add(T item)
    {
        this.LoadIfNecessary();
        this.ThrowIfInvalid(item);
        this.collection.Add(item);
    }

    void ICollection<T>.Clear()
    {
        this.LoadIfNecessary();
        this.collection.Clear();
    }

    Boolean ICollection<T>.Contains(T item)
    {
        this.LoadIfNecessary();
        return (this.collection.Contains(item));
    }

    void ICollection<T>.CopyTo(T[] array, Int32 arrayIndex)
    {
        this.LoadIfNecessary();
        this.collection.CopyTo(array, arrayIndex);
    }

    Int32 ICollection<T>.Count
    {
        get
        {
            this.LoadIfNecessary();
            return (this.collection.Count);
        }
    }

    Boolean ICollection<T>.IsReadOnly
    {
        get
        {
            this.LoadIfNecessary();
            return (this.collection.IsReadOnly);
        }
    }

    Boolean ICollection<T>.Remove(T item)
    {
        this.LoadIfNecessary();
        return (this.collection.Remove(item));
    }

    #endregion

    #region IEnumerable<T> Members

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        this.LoadIfNecessary();
        return (this.collection.GetEnumerator());
    }

    #endregion

    #region IEnumerable Members

    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((this as IEnumerable<T>).GetEnumerator());
    }

    #endregion
}
This collection receives a reference to a possibly existing collection and the DbCollectionEntry responsible for loading it. We use a bit of reflection magic to make DbCollectionEntry believe that the collection is already loaded (IsLoaded), and load it ourselves instead, by applying our custom restriction to the query returned by its Query method.
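The load-on-first-access logic can be isolated from Entity Framework. In the sketch below, which is an assumption-based illustration rather than the original class, a Func<IQueryable<T>> delegate stands in for DbCollectionEntry.Query(), and a plain isLoaded flag replaces the reflected _isLoaded field; LazyFilteredCollection<T> and LoadCount are hypothetical names.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-in for FilteredCollection<T>: a delegate replaces
// DbCollectionEntry.Query(), and a plain flag replaces the reflected _isLoaded.
public class LazyFilteredCollection<T> : ICollection<T>
{
    private readonly Func<IQueryable<T>> query;
    private readonly Expression<Func<T, bool>> filter;
    private readonly Func<T, bool> compiledFilter;
    private List<T> items = new List<T>();
    private bool isLoaded;

    public LazyFilteredCollection(Func<IQueryable<T>> query, Expression<Func<T, bool>> filter)
    {
        this.query = query;
        this.filter = filter;
        this.compiledFilter = filter.Compile();
    }

    // How many times the underlying source has been queried
    public int LoadCount { get; private set; }

    private void LoadIfNecessary()
    {
        if (this.isLoaded) return;
        // The restriction is part of the query, so only matching rows are fetched
        this.items = this.query().Where(this.filter).ToList();
        this.isLoaded = true;
        this.LoadCount++;
    }

    public void Add(T item)
    {
        this.LoadIfNecessary();
        if (!this.compiledFilter(item))
            throw new ArgumentException("Entity does not match filter", nameof(item));
        this.items.Add(item);
    }

    public void Clear() { this.LoadIfNecessary(); this.items.Clear(); }
    public bool Contains(T item) { this.LoadIfNecessary(); return this.items.Contains(item); }
    public void CopyTo(T[] array, int arrayIndex) { this.LoadIfNecessary(); this.items.CopyTo(array, arrayIndex); }
    public bool Remove(T item) { this.LoadIfNecessary(); return this.items.Remove(item); }
    public int Count { get { this.LoadIfNecessary(); return this.items.Count; } }
    public bool IsReadOnly => false;

    public IEnumerator<T> GetEnumerator() { this.LoadIfNecessary(); return this.items.GetEnumerator(); }
    IEnumerator IEnumerable.GetEnumerator() => this.GetEnumerator();
}
```

Nothing is queried until the first member access, and the source is hit only once; afterwards all operations work on the cached, already-filtered list.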
Now, in order to use this collection, we must intercept the ObjectMaterialized event of the underlying ObjectContext. We set up the filter through an extension method over DbContext:
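using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Data.Objects; // EF 4.1/5; in EF 6 use System.Data.Entity.Core.Objects
using System.Linq.Expressions;

public static class DbContextExtensions
{
    public static void Filter<TContext, TParentEntity, TCollectionEntity>(this TContext context, Expression<Func<TContext, IDbSet<TParentEntity>>> path, Expression<Func<TParentEntity, ICollection<TCollectionEntity>>> collection, Expression<Func<TCollectionEntity, Boolean>> filter)
        where TContext : DbContext
        where TParentEntity : class, new()
        where TCollectionEntity : class
    {
        // path is never evaluated; it only lets the compiler infer TParentEntity.
        // Read the navigation property name (e.g. "Details") from the expression tree.
        Expression body = collection.Body;
        if (body.NodeType == ExpressionType.Convert)
        {
            body = ((UnaryExpression)body).Operand;
        }
        String navigationProperty = ((MemberExpression)body).Member.Name;

        (context as IObjectContextAdapter).ObjectContext.ObjectMaterialized += delegate(Object sender, ObjectMaterializedEventArgs e)
        {
            if (e.Entity is TParentEntity)
            {
                DbCollectionEntry col = context.Entry(e.Entity).Collection(navigationProperty);
                col.CurrentValue = new FilteredCollection<TCollectionEntity>(null, col, filter);
            }
        };
    }
}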
Entity Framework exposes the actual ICollection<T> through the DbCollectionEntry.CurrentValue property, so we set that property to our own collection.
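To call Collection(name), the extension method needs the navigation property name as a string. Recovering it from a lambda such as p => p.Details is more robust when done by inspecting the expression tree than by parsing the lambda's ToString() output. A minimal sketch of such a helper follows; ExpressionHelper, Product and Detail are hypothetical names used only for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

public static class ExpressionHelper
{
    // Returns "Details" for p => p.Details by reading the expression tree.
    public static string GetMemberName<TSource, TResult>(Expression<Func<TSource, TResult>> selector)
    {
        Expression body = selector.Body;

        // A value-typed member selected as object, e.g. d => (object)d.SomeProperty,
        // arrives wrapped in a Convert node; unwrap it first
        if (body is UnaryExpression unary && unary.NodeType == ExpressionType.Convert)
        {
            body = unary.Operand;
        }

        // Accept only a direct member access on the lambda's own parameter
        if (body is MemberExpression member && member.Expression == selector.Parameters[0])
        {
            return member.Member.Name;
        }

        throw new ArgumentException("Expected a direct property access such as p => p.Details", nameof(selector));
    }
}

// Hypothetical entity types for the usage example
public class Detail { public int SomeProperty { get; set; } }
public class Product { public ICollection<Detail> Details { get; set; } }
```

Anything that is not a plain property access, such as d => d.SomeProperty + 1, is rejected with an ArgumentException instead of producing a meaningless name.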
We use this extension method like this:
ctx.Filter(p => p.Products, p => p.Details, p => p.SomeProperty == 0);
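This filters the Details collection of every Product entity the context materializes, keeping only details whose SomeProperty is 0. The first lambda, which selects the Products context collection, only serves to identify the parent entity type.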
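The extension method combines two ideas: hooking into entity materialization, and swapping the parent's collection for a filtered one. A stripped-down, database-free sketch of that hook follows. Materializer and MaterializedEventArgs are hypothetical stand-ins for ObjectContext and ObjectMaterializedEventArgs, Product and Detail are hypothetical entities, and, unlike FilteredCollection<T>, this sketch filters an already-populated list eagerly instead of querying lazily.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity types
public class Detail { public int SomeProperty { get; set; } }
public class Product { public ICollection<Detail> Details { get; set; } = new List<Detail>(); }

public class MaterializedEventArgs : EventArgs
{
    public MaterializedEventArgs(object entity) { this.Entity = entity; }
    public object Entity { get; }
}

// Hypothetical materializer standing in for ObjectContext: it raises an event
// right after each entity is created, much like ObjectMaterialized
public class Materializer
{
    public event EventHandler<MaterializedEventArgs> EntityMaterialized;

    public T Materialize<T>(Func<T> create) where T : class
    {
        T entity = create();
        this.EntityMaterialized?.Invoke(this, new MaterializedEventArgs(entity));
        return entity;
    }
}

public static class MaterializerExtensions
{
    // Registers a handler that replaces a parent's child collection with a filtered copy
    public static void Filter<TParent, TChild>(this Materializer materializer,
        Func<TParent, ICollection<TChild>> getCollection,
        Action<TParent, ICollection<TChild>> setCollection,
        Func<TChild, bool> filter) where TParent : class
    {
        materializer.EntityMaterialized += (sender, e) =>
        {
            if (e.Entity is TParent parent)
            {
                var filtered = getCollection(parent).Where(filter).ToList();
                setCollection(parent, filtered);
            }
        };
    }
}
```

As in the original, the filter is registered once per context (here, per materializer), and every parent entity created afterwards receives the restricted collection.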
Conclusion
As you can see, even though Entity Framework does not have all the functionality we might be used to, it still offers enough extensibility points to let us build it ourselves. The same technique presented here can be used to build lazy-loaded or even IQueryable<T> collections, both interesting ideas that I leave as an exercise to you!