State Machines With Enumerations

I have written two small classes that can be used to implement a simple state machine over an enumerated type (enum).

You first apply a custom attribute to individual enumeration values, declaring which values each one can move to directly and whether it is an initial or final state; you can then check whether a given transition is allowed.

Let's look at the code:

 

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Text;

namespace StateMachine

{

  [Serializable]

  [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = false)]

  public sealed class StateAttribute: Attribute

  {

      #region Private fields

      private Enum[] nextStates = null;

      private Boolean final = false;

      private Boolean initial = false;

      #endregion

      #region Public constructors

      public StateAttribute(Boolean initial, Boolean final, params Object[] nextStates)

      {

          Debug.Assert(nextStates != null, "nextStates is null");

          List<Enum> list = new List<Enum>();

          foreach (Object nextState in nextStates)

          {

              Debug.Assert(nextState is Enum, "value is not an enumerated type");

              list.Add((Enum)nextState);

          }

          this.Initial = initial;

          this.Final = final;

          this.nextStates = list.ToArray();

      }

      public StateAttribute(params Object[] nextStates): this(false, false, nextStates)

      {

      }

      #endregion

      #region Public properties

      public Boolean Initial

      {

          get

          {

              return (this.initial);

          }

          set

          {

              this.initial = value;

          }

      }

      public Boolean Final

      {

          get

          {

              return (this.final);

          }

          set

          {

              this.final = value;

          }

      }

      public Enum[] NextStates

      {

          get

          {

              return (this.nextStates);

          }

      }

      #endregion

      #region Public override methods

      public override Boolean Equals(Object obj)

      {

          if (!(obj is StateAttribute))

          {

              return (false);

          }

          if ((Object) this == obj)

          {

              return (true);

          }

          StateAttribute other = obj as StateAttribute;

          if (other.nextStates.Length != this.nextStates.Length)

          {

              return (false);

          }

          foreach (Enum e in this.nextStates)

          {

              if (Array.IndexOf(other.nextStates, e) < 0)

              {

                  return (false);

              }

          }

          return((this.initial == other.initial) && (this.final == other.final));

      }

      #endregion

      }

 

    public static class StateMachine

    {

        #region Public static methods

        public static Boolean HasStateMachine(Type enumType)

        {

            return (GetStates(enumType).Length != 0);

        }

        public static Enum[] GetInitialStates(Type enumType)

        {

            Debug.Assert(enumType != null, "enumType is null");

            List<Enum> states = new List<Enum>();

            foreach (Enum state in GetStates(enumType))

            {

                FieldInfo fi = enumType.GetField(state.ToString());

                 Debug.Assert(fi!= null, "Field not found");

                StateAttribute s = Attribute.GetCustomAttribute(fi, typeof(StateAttribute)) as StateAttribute;

                if ((s != null) && (s.Initial == true))

                {

                    states.Add(state);

                }

          }

          return (states.ToArray());

        }

        public static Enum[] GetFinalStates(Type enumType)

        {

            Debug.Assert(enumType != null, "enumType is null");

            List<Enum> states = new List<Enum>();

            foreach (Enum state in GetStates(enumType))

            {

                FieldInfo fi = enumType.GetField(state.ToString());

                 Debug.Assert(fi!= null, "Field is null");

                StateAttribute s = Attribute.GetCustomAttribute(fi, typeof(StateAttribute)) as StateAttribute;

                if ((s != null) && (s.Final == true))

                {

                    states.Add(state);

                }

            }

            return (states.ToArray());

        }

        public static Enum[] GetStates(Type enumType)

        {

            Debug.Assert(enumType != null, "enumType is null");

             Debug.Assert(enumType.IsEnum == true, "enumType is not an enum");

            List<Enum> states = new List<Enum>();

            foreach (FieldInfo fi in enumType.GetFields())

            {

                StateAttribute s = Attribute.GetCustomAttribute(fi, typeof(StateAttribute)) as StateAttribute;

                if (s != null)

                {

                    states.Add((
Enum) fi.GetValue(null));

                }

            }

            return (states.ToArray());

        }

        public static Boolean CanTransition(Enum initialState, Enum finalState)

        {

            Debug.Assert(initialState.GetType() == finalState.GetType(), "states are not of the same type");

            return (CanTransition(initialState, finalState, new List<Enum>()));

        }

        public static Enum[] NextStates(Enum state)

        {

            FieldInfo fi = state.GetType().GetField(state.ToString());

            Debug.Assert(fi != null, "Field not found");

            StateAttribute s = Attribute.GetCustomAttribute(fi, typeof(StateAttribute)) as StateAttribute;

            Debug.Assert(s != null, "Attribute not found");  

            List<Enum> nextStates = new List<Enum>();

            foreach (Object nextState in s.NextStates)

            {

                if (nextState is Enum)

                {

                    if (nextStates.Contains((Enum) nextState) == false)

                    {

                        nextStates.Add((
Enum) nextState);

                    }

                }

            }

            return (nextStates.ToArray());

        }

        public static Enum[] PreviousStates(Enum state)

        {

            List<Enum> states = new List<Enum>();

            foreach (Enum s in GetStates(state.GetType()))

            {

                if (Array.IndexOf(NextStates(s), state) >= 0)

                {

                    if (states.Contains(s) == false)

                    {

                        states.Add(s);

                    }

                }

            }

            return(states.ToArray());

        }

        public static Boolean IsFinal(Enum state)

        {

            FieldInfo fi = state.GetType().GetField(state.ToString());

            Debug.Assert(fi != null, "Field not found");

            StateAttribute s = Attribute.GetCustomAttribute(fi, typeof(StateAttribute)) as StateAttribute;

            return ((s != null) && (s.Final == true));

        }

        public static Boolean IsInitial(Enum state)

        {

            FieldInfo fi = state.GetType().GetField(state.ToString());

            Debug.Assert(fi != null, "Field not found");

            StateAttribute s = Attribute.GetCustomAttribute(fi, typeof(StateAttribute)) as StateAttribute;

            return ((s != null) && (s.Initial == true));

        }

        #endregion

        #region Private static methods

        private static Boolean CanTransition(Enum initialState, Enum finalState, List<Enum> processedStates)

        {

           foreach (Enum state in NextStates(initialState))

           {

               if (processedStates.Contains(state) == true)

               {

                   continue;

               }

               processedStates.Add(state);

               if (state.Equals(finalState) == true)

               {

                   return (true);

               }

                return (CanTransition(state, finalState, processedStates));

            }

            return (false);

        }

        #endregion

        }

}

 

And here's a quick sample:

 

/// <summary>
/// Sample state machine. A is the only initial state and E the only final state;
/// A branches to B1 and B2, B1 reaches E through C and D2, and B2 reaches E through D1.
/// </summary>
public enum State
{
    /// <summary>Initial state; moves to B1 or B2.</summary>
    [StateAttribute(State.B1, State.B2, Initial = true)]
    A,

    /// <summary>Moves to C.</summary>
    [StateAttribute(State.C)]
    B1,

    /// <summary>Moves to D1.</summary>
    [StateAttribute(State.D1)]
    B2,

    /// <summary>Moves to D2.</summary>
    [StateAttribute(State.D2)]
    C,

    /// <summary>Moves to E.</summary>
    [StateAttribute(State.E)]
    D1,

    /// <summary>Moves to E.</summary>
    [StateAttribute(State.E)]
    D2,

    /// <summary>Final state; has no successors.</summary>
    [StateAttribute(Final = true)]
    E
}

 

// Direct transition: B1 is listed as a successor of A.
Boolean t1 = StateMachine.CanTransition(State.A, State.B1); //true

// No path leads back to A from D1 (D1 -> E, and E has no successors).
Boolean t2 = StateMachine.CanTransition(State.D1, State.A); //false

Boolean sm = StateMachine.HasStateMachine(typeof(State)); //true

Boolean f = StateMachine.IsFinal(State.E); //true

Enum[] initialStates = StateMachine.GetInitialStates(typeof(State)); //A

Enum[] secondLevelStates = StateMachine.NextStates(State.A); //B1, B2

// E is the only value annotated with Final = true.
Enum[] finalStates = StateMachine.GetFinalStates(typeof(State)); //E

                             

No Comments