The code example in this topic is an implementation of an expression tree visitor. This class is designed to be inherited to create more specialized classes whose functionality requires traversing, examining or copying an expression tree.
The following topics use this class:
Both of these topics contain code that creates a specialized visitor subclass of the expression tree visitor base class presented in this topic.
In this expression tree visitor implementation, the Visit method, which should be called first, dispatches the expression it is passed to one of the more specialized visitor methods in the class, based on the node type of the expression. Each specialized visitor method visits the sub-tree of the expression it is passed. If a sub-expression changes when it is visited, for example because an overriding method in a derived class returns a different node, the specialized visitor method creates a new expression that includes the change. Otherwise, it returns the expression that it was passed. This recursive behavior builds an expression tree that is either the same as the original expression passed to Visit or a modified version of it.
    Public MustInherit Class ExpressionVisitor
        Protected Sub New()
        End Sub

        Protected Overridable Function Visit(ByVal exp As Expression) As Expression
            If exp Is Nothing Then
                Return exp
            End If
            Select Case exp.NodeType
                Case ExpressionType.Negate, _
                     ExpressionType.NegateChecked, _
                     ExpressionType.Not, _
                     ExpressionType.Convert, _
                     ExpressionType.ConvertChecked, _
                     ExpressionType.ArrayLength, _
                     ExpressionType.Quote, _
                     ExpressionType.TypeAs
                    Return Me.VisitUnary(CType(exp, UnaryExpression))
                Case ExpressionType.Add, _
                     ExpressionType.AddChecked, _
                     ExpressionType.Subtract, _
                     ExpressionType.SubtractChecked, _
                     ExpressionType.Multiply, _
                     ExpressionType.MultiplyChecked, _
                     ExpressionType.Divide, _
                     ExpressionType.Modulo, _
                     ExpressionType.And, _
                     ExpressionType.AndAlso, _
                     ExpressionType.Or, _
                     ExpressionType.OrElse, _
                     ExpressionType.LessThan, _
                     ExpressionType.LessThanOrEqual, _
                     ExpressionType.GreaterThan, _
                     ExpressionType.GreaterThanOrEqual, _
                     ExpressionType.Equal, _
                     ExpressionType.NotEqual, _
                     ExpressionType.Coalesce, _
                     ExpressionType.ArrayIndex, _
                     ExpressionType.RightShift, _
                     ExpressionType.LeftShift, _
                     ExpressionType.ExclusiveOr
                    Return Me.VisitBinary(CType(exp, BinaryExpression))
                Case ExpressionType.TypeIs
                    Return Me.VisitTypeIs(CType(exp, TypeBinaryExpression))
                Case ExpressionType.Conditional
                    Return Me.VisitConditional(CType(exp, ConditionalExpression))
                Case ExpressionType.Constant
                    Return Me.VisitConstant(CType(exp, ConstantExpression))
                Case ExpressionType.Parameter
                    Return Me.VisitParameter(CType(exp, ParameterExpression))
                Case ExpressionType.MemberAccess
                    Return Me.VisitMemberAccess(CType(exp, MemberExpression))
                Case ExpressionType.Call
                    Return Me.VisitMethodCall(CType(exp, MethodCallExpression))
                Case ExpressionType.Lambda
                    Return Me.VisitLambda(CType(exp, LambdaExpression))
                Case ExpressionType.New
                    Return Me.VisitNew(CType(exp, NewExpression))
                Case ExpressionType.NewArrayInit, _
                     ExpressionType.NewArrayBounds
                    Return Me.VisitNewArray(CType(exp, NewArrayExpression))
                Case ExpressionType.Invoke
                    Return Me.VisitInvocation(CType(exp, InvocationExpression))
                Case ExpressionType.MemberInit
                    Return Me.VisitMemberInit(CType(exp, MemberInitExpression))
                Case ExpressionType.ListInit
                    Return Me.VisitListInit(CType(exp, ListInitExpression))
                Case Else
                    Throw New Exception("Unhandled expression type: '" & exp.NodeType & "'")
            End Select
        End Function

        Protected Overridable Function VisitBinding(ByVal binding As MemberBinding) As MemberBinding
            Select Case binding.BindingType
                Case MemberBindingType.Assignment
                    Return Me.VisitMemberAssignment(CType(binding, MemberAssignment))
                Case MemberBindingType.MemberBinding
                    Return Me.VisitMemberMemberBinding(CType(binding, MemberMemberBinding))
                Case MemberBindingType.ListBinding
                    Return Me.VisitMemberListBinding(CType(binding, MemberListBinding))
                Case Else
                    Throw New Exception("Unhandled binding type '" & binding.BindingType & "'")
            End Select
        End Function

        Protected Overridable Function VisitElementInitializer(ByVal initializer As ElementInit) _
                As ElementInit
            Dim arguments = Me.VisitExpressionList(initializer.Arguments)
            If arguments IsNot initializer.Arguments Then
                Return Expression.ElementInit(initializer.AddMethod, arguments)
            End If
            Return initializer
        End Function

        Protected Overridable Function VisitUnary(ByVal u As UnaryExpression) As Expression
            Dim operand = Me.Visit(u.Operand)
            If operand IsNot u.Operand Then
                Return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method)
            End If
            Return u
        End Function

        Protected Overridable Function VisitBinary(ByVal b As BinaryExpression) As Expression
            Dim left = Me.Visit(b.Left)
            Dim right = Me.Visit(b.Right)
            Dim conversion = Me.Visit(b.Conversion)
            If left IsNot b.Left Or right IsNot b.Right Or conversion IsNot b.Conversion Then
                If b.NodeType = ExpressionType.Coalesce And b.Conversion IsNot Nothing Then
                    Return Expression.Coalesce(left, right, _
                                               TryCast(conversion, LambdaExpression))
                Else
                    Return Expression.MakeBinary(b.NodeType, left, right, _
                                                 b.IsLiftedToNull, b.Method)
                End If
            End If
            Return b
        End Function

        Protected Overridable Function VisitTypeIs(ByVal b As TypeBinaryExpression) As Expression
            Dim expr = Me.Visit(b.Expression)
            If expr IsNot b.Expression Then
                Return Expression.TypeIs(expr, b.TypeOperand)
            End If
            Return b
        End Function

        Protected Overridable Function VisitConstant(ByVal c As ConstantExpression) As Expression
            Return c
        End Function

        Protected Overridable Function VisitConditional(ByVal c As ConditionalExpression) As Expression
            Dim test = Me.Visit(c.Test)
            Dim ifTrue = Me.Visit(c.IfTrue)
            Dim ifFalse = Me.Visit(c.IfFalse)
            If test IsNot c.Test Or ifTrue IsNot c.IfTrue Or ifFalse IsNot c.IfFalse Then
                Return Expression.Condition(test, ifTrue, ifFalse)
            End If
            Return c
        End Function

        Protected Overridable Function VisitParameter(ByVal p As ParameterExpression) As Expression
            Return p
        End Function

        Protected Overridable Function VisitMemberAccess(ByVal m As MemberExpression) As Expression
            Dim exp = Me.Visit(m.Expression)
            If exp IsNot m.Expression Then
                Return Expression.MakeMemberAccess(exp, m.Member)
            End If
            Return m
        End Function

        Protected Overridable Function VisitMethodCall(ByVal m As MethodCallExpression) As Expression
            Dim obj = Me.Visit(m.Object)
            Dim args = Me.VisitExpressionList(m.Arguments)
            If obj IsNot m.Object Or args IsNot m.Arguments Then
                Return Expression.Call(obj, m.Method, args)
            End If
            Return m
        End Function

        Protected Overridable Function VisitExpressionList( _
                ByVal original As ReadOnlyCollection(Of Expression)) As ReadOnlyCollection(Of Expression)
            Dim list As List(Of Expression) = Nothing
            Dim n = original.Count
            For i = 0 To n - 1
                Dim p = Me.Visit(original(i))
                If list IsNot Nothing Then
                    list.Add(p)
                ElseIf p IsNot original(i) Then
                    list = New List(Of Expression)(n)
                    For j = 0 To i - 1
                        list.Add(original(j))
                    Next j
                    list.Add(p)
                End If
            Next i
            If list IsNot Nothing Then
                Return list.AsReadOnly()
            End If
            Return original
        End Function

        Protected Overridable Function VisitMemberAssignment(ByVal assignment As MemberAssignment) _
                As MemberAssignment
            Dim e = Me.Visit(assignment.Expression)
            If e IsNot assignment.Expression Then
                Return Expression.Bind(assignment.Member, e)
            End If
            Return assignment
        End Function

        Protected Overridable Function VisitMemberMemberBinding(ByVal binding As MemberMemberBinding) _
                As MemberMemberBinding
            Dim bindings = Me.VisitBindingList(binding.Bindings)
            If bindings IsNot binding.Bindings Then
                Return Expression.MemberBind(binding.Member, bindings)
            End If
            Return binding
        End Function

        Protected Overridable Function VisitMemberListBinding(ByVal binding As MemberListBinding) _
                As MemberListBinding
            Dim initializers = Me.VisitElementInitializerList(binding.Initializers)
            If initializers IsNot binding.Initializers Then
                Return Expression.ListBind(binding.Member, initializers)
            End If
            Return binding
        End Function

        Protected Overridable Function VisitBindingList( _
                ByVal original As ReadOnlyCollection(Of MemberBinding)) As IEnumerable(Of MemberBinding)
            Dim list As List(Of MemberBinding) = Nothing
            Dim n = original.Count
            For i = 0 To n - 1
                Dim b = Me.VisitBinding(original(i))
                If list IsNot Nothing Then
                    list.Add(b)
                ElseIf b IsNot original(i) Then
                    list = New List(Of MemberBinding)(n)
                    For j = 0 To i - 1
                        list.Add(original(j))
                    Next j
                    list.Add(b)
                End If
            Next i
            If list IsNot Nothing Then
                Return list
            End If
            Return original
        End Function

        Protected Overridable Function VisitElementInitializerList( _
                ByVal original As ReadOnlyCollection(Of ElementInit)) As IEnumerable(Of ElementInit)
            Dim list As List(Of ElementInit) = Nothing
            Dim n = original.Count
            For i = 0 To n - 1
                Dim init = Me.VisitElementInitializer(original(i))
                If list IsNot Nothing Then
                    list.Add(init)
                ElseIf init IsNot original(i) Then
                    list = New List(Of ElementInit)(n)
                    For j = 0 To i - 1
                        list.Add(original(j))
                    Next j
                    list.Add(init)
                End If
            Next i
            If list IsNot Nothing Then
                Return list
            End If
            Return original
        End Function

        Protected Overridable Function VisitLambda(ByVal lambda As LambdaExpression) As Expression
            Dim body = Me.Visit(lambda.Body)
            If body IsNot lambda.Body Then
                Return Expression.Lambda(lambda.Type, body, lambda.Parameters)
            End If
            Return lambda
        End Function

        Protected Overridable Function VisitNew(ByVal nex As NewExpression) As NewExpression
            Dim args = Me.VisitExpressionList(nex.Arguments)
            If args IsNot nex.Arguments Then
                If nex.Members IsNot Nothing Then
                    Return Expression.[New](nex.Constructor, args, nex.Members)
                Else
                    Return Expression.[New](nex.Constructor, args)
                End If
            End If
            Return nex
        End Function

        Protected Overridable Function VisitMemberInit(ByVal init As MemberInitExpression) As Expression
            Dim n = Me.VisitNew(init.NewExpression)
            Dim bindings = Me.VisitBindingList(init.Bindings)
            If n IsNot init.NewExpression Or bindings IsNot init.Bindings Then
                Return Expression.MemberInit(n, bindings)
            End If
            Return init
        End Function

        Protected Overridable Function VisitListInit(ByVal init As ListInitExpression) As Expression
            Dim n = Me.VisitNew(init.NewExpression)
            Dim initializers = Me.VisitElementInitializerList(init.Initializers)
            If n IsNot init.NewExpression Or initializers IsNot init.Initializers Then
                Return Expression.ListInit(n, initializers)
            End If
            Return init
        End Function

        Protected Overridable Function VisitNewArray(ByVal na As NewArrayExpression) As Expression
            Dim exprs = Me.VisitExpressionList(na.Expressions)
            If exprs IsNot na.Expressions Then
                If na.NodeType = ExpressionType.NewArrayInit Then
                    Return Expression.NewArrayInit(na.Type.GetElementType(), exprs)
                Else
                    Return Expression.NewArrayBounds(na.Type.GetElementType(), exprs)
                End If
            End If
            Return na
        End Function

        Protected Overridable Function VisitInvocation(ByVal iv As InvocationExpression) As Expression
            Dim args = Me.VisitExpressionList(iv.Arguments)
            Dim expr = Me.Visit(iv.Expression)
            If args IsNot iv.Arguments Or expr IsNot iv.Expression Then
                Return Expression.Invoke(expr, args)
            End If
            Return iv
        End Function
    End Class
    public abstract class ExpressionVisitor
    {
        protected ExpressionVisitor()
        {
        }

        protected virtual Expression Visit(Expression exp)
        {
            if (exp == null)
                return exp;
            switch (exp.NodeType)
            {
                case ExpressionType.Negate:
                case ExpressionType.NegateChecked:
                case ExpressionType.Not:
                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                case ExpressionType.ArrayLength:
                case ExpressionType.Quote:
                case ExpressionType.TypeAs:
                    return this.VisitUnary((UnaryExpression)exp);
                case ExpressionType.Add:
                case ExpressionType.AddChecked:
                case ExpressionType.Subtract:
                case ExpressionType.SubtractChecked:
                case ExpressionType.Multiply:
                case ExpressionType.MultiplyChecked:
                case ExpressionType.Divide:
                case ExpressionType.Modulo:
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                case ExpressionType.Or:
                case ExpressionType.OrElse:
                case ExpressionType.LessThan:
                case ExpressionType.LessThanOrEqual:
                case ExpressionType.GreaterThan:
                case ExpressionType.GreaterThanOrEqual:
                case ExpressionType.Equal:
                case ExpressionType.NotEqual:
                case ExpressionType.Coalesce:
                case ExpressionType.ArrayIndex:
                case ExpressionType.RightShift:
                case ExpressionType.LeftShift:
                case ExpressionType.ExclusiveOr:
                    return this.VisitBinary((BinaryExpression)exp);
                case ExpressionType.TypeIs:
                    return this.VisitTypeIs((TypeBinaryExpression)exp);
                case ExpressionType.Conditional:
                    return this.VisitConditional((ConditionalExpression)exp);
                case ExpressionType.Constant:
                    return this.VisitConstant((ConstantExpression)exp);
                case ExpressionType.Parameter:
                    return this.VisitParameter((ParameterExpression)exp);
                case ExpressionType.MemberAccess:
                    return this.VisitMemberAccess((MemberExpression)exp);
                case ExpressionType.Call:
                    return this.VisitMethodCall((MethodCallExpression)exp);
                case ExpressionType.Lambda:
                    return this.VisitLambda((LambdaExpression)exp);
                case ExpressionType.New:
                    return this.VisitNew((NewExpression)exp);
                case ExpressionType.NewArrayInit:
                case ExpressionType.NewArrayBounds:
                    return this.VisitNewArray((NewArrayExpression)exp);
                case ExpressionType.Invoke:
                    return this.VisitInvocation((InvocationExpression)exp);
                case ExpressionType.MemberInit:
                    return this.VisitMemberInit((MemberInitExpression)exp);
                case ExpressionType.ListInit:
                    return this.VisitListInit((ListInitExpression)exp);
                default:
                    throw new Exception(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
            }
        }

        protected virtual MemberBinding VisitBinding(MemberBinding binding)
        {
            switch (binding.BindingType)
            {
                case MemberBindingType.Assignment:
                    return this.VisitMemberAssignment((MemberAssignment)binding);
                case MemberBindingType.MemberBinding:
                    return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
                case MemberBindingType.ListBinding:
                    return this.VisitMemberListBinding((MemberListBinding)binding);
                default:
                    throw new Exception(string.Format("Unhandled binding type '{0}'", binding.BindingType));
            }
        }

        protected virtual ElementInit VisitElementInitializer(ElementInit initializer)
        {
            ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
            if (arguments != initializer.Arguments)
            {
                return Expression.ElementInit(initializer.AddMethod, arguments);
            }
            return initializer;
        }

        protected virtual Expression VisitUnary(UnaryExpression u)
        {
            Expression operand = this.Visit(u.Operand);
            if (operand != u.Operand)
            {
                return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
            }
            return u;
        }

        protected virtual Expression VisitBinary(BinaryExpression b)
        {
            Expression left = this.Visit(b.Left);
            Expression right = this.Visit(b.Right);
            Expression conversion = this.Visit(b.Conversion);
            if (left != b.Left || right != b.Right || conversion != b.Conversion)
            {
                if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
                    return Expression.Coalesce(left, right, conversion as LambdaExpression);
                else
                    return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
            }
            return b;
        }

        protected virtual Expression VisitTypeIs(TypeBinaryExpression b)
        {
            Expression expr = this.Visit(b.Expression);
            if (expr != b.Expression)
            {
                return Expression.TypeIs(expr, b.TypeOperand);
            }
            return b;
        }

        protected virtual Expression VisitConstant(ConstantExpression c)
        {
            return c;
        }

        protected virtual Expression VisitConditional(ConditionalExpression c)
        {
            Expression test = this.Visit(c.Test);
            Expression ifTrue = this.Visit(c.IfTrue);
            Expression ifFalse = this.Visit(c.IfFalse);
            if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse)
            {
                return Expression.Condition(test, ifTrue, ifFalse);
            }
            return c;
        }

        protected virtual Expression VisitParameter(ParameterExpression p)
        {
            return p;
        }

        protected virtual Expression VisitMemberAccess(MemberExpression m)
        {
            Expression exp = this.Visit(m.Expression);
            if (exp != m.Expression)
            {
                return Expression.MakeMemberAccess(exp, m.Member);
            }
            return m;
        }

        protected virtual Expression VisitMethodCall(MethodCallExpression m)
        {
            Expression obj = this.Visit(m.Object);
            IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
            if (obj != m.Object || args != m.Arguments)
            {
                return Expression.Call(obj, m.Method, args);
            }
            return m;
        }

        protected virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original)
        {
            List<Expression> list = null;
            for (int i = 0, n = original.Count; i < n; i++)
            {
                Expression p = this.Visit(original[i]);
                if (list != null)
                {
                    list.Add(p);
                }
                else if (p != original[i])
                {
                    list = new List<Expression>(n);
                    for (int j = 0; j < i; j++)
                    {
                        list.Add(original[j]);
                    }
                    list.Add(p);
                }
            }
            if (list != null)
            {
                return list.AsReadOnly();
            }
            return original;
        }

        protected virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment)
        {
            Expression e = this.Visit(assignment.Expression);
            if (e != assignment.Expression)
            {
                return Expression.Bind(assignment.Member, e);
            }
            return assignment;
        }

        protected virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding)
        {
            IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
            if (bindings != binding.Bindings)
            {
                return Expression.MemberBind(binding.Member, bindings);
            }
            return binding;
        }

        protected virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding)
        {
            IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
            if (initializers != binding.Initializers)
            {
                return Expression.ListBind(binding.Member, initializers);
            }
            return binding;
        }

        protected virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original)
        {
            List<MemberBinding> list = null;
            for (int i = 0, n = original.Count; i < n; i++)
            {
                MemberBinding b = this.VisitBinding(original[i]);
                if (list != null)
                {
                    list.Add(b);
                }
                else if (b != original[i])
                {
                    list = new List<MemberBinding>(n);
                    for (int j = 0; j < i; j++)
                    {
                        list.Add(original[j]);
                    }
                    list.Add(b);
                }
            }
            if (list != null)
                return list;
            return original;
        }

        protected virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original)
        {
            List<ElementInit> list = null;
            for (int i = 0, n = original.Count; i < n; i++)
            {
                ElementInit init = this.VisitElementInitializer(original[i]);
                if (list != null)
                {
                    list.Add(init);
                }
                else if (init != original[i])
                {
                    list = new List<ElementInit>(n);
                    for (int j = 0; j < i; j++)
                    {
                        list.Add(original[j]);
                    }
                    list.Add(init);
                }
            }
            if (list != null)
                return list;
            return original;
        }

        protected virtual Expression VisitLambda(LambdaExpression lambda)
        {
            Expression body = this.Visit(lambda.Body);
            if (body != lambda.Body)
            {
                return Expression.Lambda(lambda.Type, body, lambda.Parameters);
            }
            return lambda;
        }

        protected virtual NewExpression VisitNew(NewExpression nex)
        {
            IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
            if (args != nex.Arguments)
            {
                if (nex.Members != null)
                    return Expression.New(nex.Constructor, args, nex.Members);
                else
                    return Expression.New(nex.Constructor, args);
            }
            return nex;
        }

        protected virtual Expression VisitMemberInit(MemberInitExpression init)
        {
            NewExpression n = this.VisitNew(init.NewExpression);
            IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
            if (n != init.NewExpression || bindings != init.Bindings)
            {
                return Expression.MemberInit(n, bindings);
            }
            return init;
        }

        protected virtual Expression VisitListInit(ListInitExpression init)
        {
            NewExpression n = this.VisitNew(init.NewExpression);
            IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
            if (n != init.NewExpression || initializers != init.Initializers)
            {
                return Expression.ListInit(n, initializers);
            }
            return init;
        }

        protected virtual Expression VisitNewArray(NewArrayExpression na)
        {
            IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
            if (exprs != na.Expressions)
            {
                if (na.NodeType == ExpressionType.NewArrayInit)
                {
                    return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
                }
                else
                {
                    return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
                }
            }
            return na;
        }

        protected virtual Expression VisitInvocation(InvocationExpression iv)
        {
            IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
            Expression expr = this.Visit(iv.Expression);
            if (args != iv.Arguments || expr != iv.Expression)
            {
                return Expression.Invoke(expr, args);
            }
            return iv;
        }
    }
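The visitor returns the original node when nothing beneath it changed and rebuilds only the path to a changed node. A small derived visitor can check this. The following is a minimal sketch, not part of the original sample: ConstantReplacer, its Replace method, and the Demo class are hypothetical names. It assumes the ExpressionVisitor class from this topic is in scope. On .NET Framework 4 and later, the built-in System.Linq.Expressions.ExpressionVisitor exposes the same overridable VisitConstant method, so the sketch should compile against either.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical derived visitor: replaces Int32 constants equal to oldValue with newValue.
public class ConstantReplacer : ExpressionVisitor
{
    private readonly int oldValue;
    private readonly int newValue;

    public ConstantReplacer(int oldValue, int newValue)
    {
        this.oldValue = oldValue;
        this.newValue = newValue;
    }

    // Public entry point that forwards to Visit.
    public Expression Replace(Expression expression)
    {
        return Visit(expression);
    }

    protected override Expression VisitConstant(ConstantExpression c)
    {
        if (c.Value is int && (int)c.Value == oldValue)
        {
            // Returning a different node signals a change to the parent visitor methods.
            return Expression.Constant(newValue);
        }
        return base.VisitConstant(c);
    }
}

public static class Demo
{
    public static void Main()
    {
        // Build the tree (x + 1) * 2 by hand.
        ParameterExpression x = Expression.Parameter(typeof(int), "x");
        BinaryExpression tree = Expression.Multiply(
            Expression.Add(x, Expression.Constant(1)),
            Expression.Constant(2));

        // No constant equals 99, so the very same tree instance comes back.
        Expression unchanged = new ConstantReplacer(99, 0).Replace(tree);
        Console.WriteLine(ReferenceEquals(unchanged, tree));   // True

        // The constant 1 is replaced, so the Add and Multiply nodes are rebuilt.
        BinaryExpression changed = (BinaryExpression)new ConstantReplacer(1, 5).Replace(tree);
        Console.WriteLine(ReferenceEquals(changed, tree));     // False

        // The untouched right operand (the constant 2) is reused, not copied.
        Console.WriteLine(ReferenceEquals(changed.Right, tree.Right)); // True
    }
}
```

Only the nodes on the path from the root to the replaced constant are re-created. Sibling sub-trees that did not change are shared between the old and new trees.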
Note: In this implementation, the Visit method, which is the starting point for visiting an expression tree, has the protected (Protected in Visual Basic) access modifier. To make it accessible from outside the class and its derived classes, create a public (Public in Visual Basic) method that calls Visit. Having a single public method in your visitor makes the entry point obvious to callers.
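A public wrapper around Visit might look like the following. This is a minimal sketch under stated assumptions: AndAlsoToOrElseVisitor, Modify, and Program are hypothetical names, and the visitor base class is assumed to be either the ExpressionVisitor from this topic or the built-in System.Linq.Expressions.ExpressionVisitor in .NET Framework 4 and later, which has the same overridable VisitBinary method.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical derived visitor: rewrites every AndAlso (&&) node as OrElse (||).
public class AndAlsoToOrElseVisitor : ExpressionVisitor
{
    // The single public entry point. Callers never touch Visit directly.
    public Expression Modify(Expression expression)
    {
        return Visit(expression);
    }

    protected override Expression VisitBinary(BinaryExpression b)
    {
        if (b.NodeType == ExpressionType.AndAlso)
        {
            // Visit the operands first so nested AndAlso nodes are rewritten too.
            Expression left = Visit(b.Left);
            Expression right = Visit(b.Right);
            return Expression.MakeBinary(ExpressionType.OrElse, left, right,
                                         b.IsLiftedToNull, b.Method);
        }
        return base.VisitBinary(b);
    }
}

public static class Program
{
    public static void Main()
    {
        Expression<Func<string, bool>> expr =
            name => name.Length > 10 && name.StartsWith("G");

        Expression modified = new AndAlsoToOrElseVisitor().Modify(expr);

        // Print both trees to compare the operator in the lambda body.
        Console.WriteLine(expr);
        Console.WriteLine(modified);
    }
}
```

Keeping Modify as the only public member means a caller cannot accidentally start the traversal from one of the specialized methods and skip the dispatch in Visit.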
- Add a reference to System.Core.dll if it is not already referenced in your project.
- Add using directives (or Imports statements in Visual Basic) for the System.Collections.Generic, System.Collections.ObjectModel, and System.Linq.Expressions namespaces.
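For the C# version, the directives might sit at the top of the source file as shown here. The using System; line is an addition not listed in the steps; it is assumed because the C# listing refers to Exception without qualification.

```csharp
using System;                          // assumed: needed for the unqualified Exception type
using System.Collections.Generic;      // List<T>, IEnumerable<T>
using System.Collections.ObjectModel;  // ReadOnlyCollection<T>
using System.Linq.Expressions;         // Expression, ExpressionType, and the node classes
```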