Arjan Egges replaces inheritance with decorators in Python state machines

Overview

The State pattern typically solves the nightmare of massive if-else blocks and scattered boolean flags in complex logic flows. While traditional implementations rely heavily on Object-Oriented Programming principles like inheritance and polymorphism, this approach can lead to significant code duplication and logic fragmented across many classes. By refactoring the pattern into a data-driven, generic engine, we can make transitions explicit and centralized. This treats a state machine as what it fundamentally is: a lookup table mapping a current state and an event to a next state and a corresponding action.
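To make the "lookup table" framing concrete, here is a minimal sketch with no engine at all: a plain dictionary keyed by (state, event) pairs. The string states and events and the `log` and `step` helpers are illustrative assumptions, not code from the original.

```python
from typing import Callable

# A hand-written transition table for a hypothetical payment flow.
# Keys are (current_state, event); values are (next_state, action).
Action = Callable[[list[str]], None]


def log(message: str) -> Action:
    """Build an action that appends `message` to an audit list."""
    def action(audit: list[str]) -> None:
        audit.append(message)
    return action


TABLE: dict[tuple[str, str], tuple[str, Action]] = {
    ("new", "authorize"): ("authorized", log("authorized")),
    ("new", "fail"): ("failed", log("failed")),
    ("authorized", "fail"): ("failed", log("failed")),
}


def step(state: str, event: str, audit: list[str]) -> str:
    """One dictionary lookup replaces an if/elif chain over (state, event) pairs."""
    if (state, event) not in TABLE:
        raise ValueError(f"Invalid transition from {state} with {event}")
    next_state, action = TABLE[(state, event)]
    action(audit)
    return next_state


audit: list[str] = []
state = step("new", "authorize", audit)
state = step(state, "fail", audit)
print(state, audit)  # failed ['authorized', 'failed']
```

Every legal move is visible in one place; anything missing from the table is rejected rather than silently ignored.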

Prerequisites

To follow this tutorial, you should be comfortable with basic Python syntax. Specifically, familiarity with the following concepts is helpful:

  • Type Hinting: Understanding how to use generics (TypeVar).
  • Decorators: Knowing how functions can be wrapped to modify behavior.
  • Enums: Grouping related constants under a single type.
  • Data Classes: Using the dataclass decorator to generate boilerplate such as __init__ and __repr__ automatically.

Key Libraries & Tools

  • typing: Provides Generic, TypeVar, Callable, Iterable, Dict, and Tuple for type hints that a static type checker such as mypy can verify.
  • enum: Used to define distinct states (e.g., New, Authorized) and events (e.g., Authorize, Fail).
  • dataclasses: Simplifies the creation of the StateMachine and context objects.

Code Walkthrough

The core of this refactor is the StateMachine class. It uses generic types S (State), E (Event), and C (Context) to remain reusable across different domains like payments, logistics, or parsing.

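from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Dict, Generic, Iterable, Tuple, TypeVar

S = TypeVar("S")  # state type
E = TypeVar("E")  # event type
C = TypeVar("C")  # context type passed to actions

@dataclass
class StateMachine(Generic[S, E, C]):
    # Lookup table: (current_state, event) -> (next_state, action)
    transitions: Dict[Tuple[S, E], Tuple[S, Callable[[C], None]]] = field(default_factory=dict)

    def add_transition(self, from_state: S, event: E, to_state: S, action: Callable[[C], None]) -> None:
        self.transitions[(from_state, event)] = (to_state, action)

    def transition(self, from_states: S | Iterable[S], event: E, to_state: S):
        # Decorator factory: registers the decorated function for every source state.
        def decorator(action: Callable[[C], None]) -> Callable[[C], None]:
            # Treat common collections as "many states"; anything else is a single state.
            is_many = isinstance(from_states, (list, tuple, set, frozenset))
            states = from_states if is_many else [from_states]
            for s in states:
                self.add_transition(s, event, to_state, action)
            return action  # return the original function so it can be reused or stacked
        return decorator

    def handle(self, context: C, current_state: S, event: E) -> S:
        if (current_state, event) not in self.transitions:
            raise ValueError(f"Invalid transition from {current_state} with {event}")
        next_state, action = self.transitions[(current_state, event)]
        action(context)  # run the side effect (logging, persistence, ...)
        return next_state  # the caller stores the new state; the engine keeps none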

In this implementation, the transition method acts as a decorator factory. It allows us to register state changes directly above the functions that perform the business logic. The handle method performs the lookup, executes the action (like logging or database updates), and returns the resulting state.
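To illustrate that the engine is domain-agnostic and that handle rejects unregistered pairs, here is a sketch that reuses a condensed copy of the class for a logistics flow. ShipState, ShipEvent, the dispatch/arrive actions, and the list-of-strings context are hypothetical names chosen for illustration, not part of the original.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Dict, Generic, Iterable, Tuple, TypeVar

S = TypeVar("S")
E = TypeVar("E")
C = TypeVar("C")


@dataclass
class StateMachine(Generic[S, E, C]):
    # Condensed copy of the engine above.
    transitions: Dict[Tuple[S, E], Tuple[S, Callable[[C], None]]] = field(default_factory=dict)

    def transition(self, from_states: S | Iterable[S], event: E, to_state: S):
        def decorator(action: Callable[[C], None]) -> Callable[[C], None]:
            is_many = isinstance(from_states, (list, tuple, set, frozenset))
            for s in (from_states if is_many else [from_states]):
                self.transitions[(s, event)] = (to_state, action)
            return action
        return decorator

    def handle(self, context: C, current_state: S, event: E) -> S:
        if (current_state, event) not in self.transitions:
            raise ValueError(f"Invalid transition from {current_state} with {event}")
        next_state, action = self.transitions[(current_state, event)]
        action(context)  # side effect first...
        return next_state  # ...then report the new state; the engine stores none


# Hypothetical logistics domain, not from the original article.
class ShipState(Enum):
    PENDING = auto()
    SHIPPED = auto()
    DELIVERED = auto()


class ShipEvent(Enum):
    DISPATCH = auto()
    ARRIVE = auto()


ship_sm = StateMachine[ShipState, ShipEvent, list[str]]()


@ship_sm.transition(ShipState.PENDING, ShipEvent.DISPATCH, ShipState.SHIPPED)
def dispatch(history: list[str]) -> None:
    history.append("dispatched")


@ship_sm.transition(ShipState.SHIPPED, ShipEvent.ARRIVE, ShipState.DELIVERED)
def arrive(history: list[str]) -> None:
    history.append("arrived")


history: list[str] = []
state = ShipState.PENDING
for event in (ShipEvent.DISPATCH, ShipEvent.ARRIVE):
    state = ship_sm.handle(history, state, event)
```

After the loop, state is ShipState.DELIVERED; a further DISPATCH from DELIVERED raises ValueError before any action runs, because the lookup happens first.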

Syntax Notes

  • Generic Constraints: Declaring S = TypeVar("S") leaves the state type unconstrained, so the machine works with any hashable type (states and events are used as dictionary keys). It can later be bounded to enums, e.g. TypeVar("S", bound=Enum), so a type checker rejects non-Enum states.
  • Decorator Chaining: Because the transition decorator returns the original action function unchanged, decorators can be stacked, and the same logic (like a fail action) can be registered for multiple state transitions.
  • Type Union: Using S | Iterable[S] (the union syntax introduced in Python 3.10) lets the decorator accept either a single state or a collection of states, reducing boilerplate when several states share the same exit event.
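The three notes above can be combined in one sketch, assuming enum-bounded TypeVars. EnumStateMachine and the TIMEOUT event are hypothetical names introduced here; with the Enum bound, "is this a single state?" becomes an exact isinstance test, so any iterable of states can be accepted.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Generic, Iterable, TypeVar

# Bounded TypeVars: a static checker such as mypy rejects non-Enum states/events.
S = TypeVar("S", bound=Enum)
E = TypeVar("E", bound=Enum)
C = TypeVar("C")


@dataclass
class EnumStateMachine(Generic[S, E, C]):
    transitions: dict[tuple[S, E], tuple[S, Callable[[C], None]]] = field(default_factory=dict)

    def transition(self, from_states: S | Iterable[S], event: E, to_state: S):
        def decorator(action: Callable[[C], None]) -> Callable[[C], None]:
            # A single Enum member is one state; any other iterable is expanded.
            states = [from_states] if isinstance(from_states, Enum) else list(from_states)
            for s in states:
                self.transitions[(s, event)] = (to_state, action)
            return action  # returning the undecorated function is what allows stacking
        return decorator


class PayState(Enum):
    NEW = auto()
    AUTHORIZED = auto()
    FAILED = auto()


class PayEvent(Enum):
    AUTHORIZE = auto()
    FAIL = auto()
    TIMEOUT = auto()  # hypothetical extra event, not in the original example


sm = EnumStateMachine[PayState, PayEvent, list[str]]()


# Decorators apply bottom-up: the FAIL registration runs first, then TIMEOUT.
@sm.transition(PayState.NEW, PayEvent.TIMEOUT, PayState.FAILED)
@sm.transition((PayState.NEW, PayState.AUTHORIZED), PayEvent.FAIL, PayState.FAILED)
def fail_action(audit: list[str]) -> None:
    audit.append("Transaction failed")
```

One function now backs three table entries, and fail_action is still the plain, directly callable function.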

Practical Examples

Imagine a payment flow. We define states and events using Enums. The business logic becomes declarative, reading almost like a specification document.

pay_sm = StateMachine[PayState, PayEvent, PaymentContext]()

@pay_sm.transition(from_states=PayState.NEW, event=PayEvent.AUTHORIZE, to_state=PayState.AUTHORIZED)
def authorize_action(ctx: PaymentContext):
    ctx.audit.append(f"Authorized payment {ctx.id}")

@pay_sm.transition(from_states=(PayState.NEW, PayState.AUTHORIZED), event=PayEvent.FAIL, to_state=PayState.FAILED)
def fail_action(ctx: PaymentContext):
    ctx.audit.append("Transaction failed")
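The snippet above uses PayState, PayEvent, and PaymentContext without defining them. Here is a runnable sketch that supplies assumed definitions (their members and fields are inferred from the snippet, not taken from the original) together with a condensed copy of the engine, then drives one payment through the flow.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Dict, Generic, Iterable, Tuple, TypeVar

S = TypeVar("S")
E = TypeVar("E")
C = TypeVar("C")


@dataclass
class StateMachine(Generic[S, E, C]):
    # Condensed copy of the engine from the walkthrough.
    transitions: Dict[Tuple[S, E], Tuple[S, Callable[[C], None]]] = field(default_factory=dict)

    def transition(self, from_states: S | Iterable[S], event: E, to_state: S):
        def decorator(action: Callable[[C], None]) -> Callable[[C], None]:
            is_many = isinstance(from_states, (list, tuple, set, frozenset))
            for s in (from_states if is_many else [from_states]):
                self.transitions[(s, event)] = (to_state, action)
            return action
        return decorator

    def handle(self, context: C, current_state: S, event: E) -> S:
        if (current_state, event) not in self.transitions:
            raise ValueError(f"Invalid transition from {current_state} with {event}")
        next_state, action = self.transitions[(current_state, event)]
        action(context)
        return next_state


# Assumed definitions: the snippet uses these names without showing them.
class PayState(Enum):
    NEW = auto()
    AUTHORIZED = auto()
    FAILED = auto()


class PayEvent(Enum):
    AUTHORIZE = auto()
    FAIL = auto()


@dataclass
class PaymentContext:
    id: str
    audit: list[str] = field(default_factory=list)


pay_sm = StateMachine[PayState, PayEvent, PaymentContext]()


@pay_sm.transition(from_states=PayState.NEW, event=PayEvent.AUTHORIZE, to_state=PayState.AUTHORIZED)
def authorize_action(ctx: PaymentContext) -> None:
    ctx.audit.append(f"Authorized payment {ctx.id}")


@pay_sm.transition(from_states=(PayState.NEW, PayState.AUTHORIZED), event=PayEvent.FAIL, to_state=PayState.FAILED)
def fail_action(ctx: PaymentContext) -> None:
    ctx.audit.append("Transaction failed")


ctx = PaymentContext(id="p-1")
state = pay_sm.handle(ctx, PayState.NEW, PayEvent.AUTHORIZE)
state = pay_sm.handle(ctx, state, PayEvent.FAIL)
print(state, ctx.audit)  # PayState.FAILED ['Authorized payment p-1', 'Transaction failed']
```

Trying to AUTHORIZE from FAILED raises ValueError, since that pair was never registered.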

Tips & Gotchas

  • Keep State in the Object: Do not store the current state inside the StateMachine instance. Instead, keep it in the subject (e.g., the Payment class). This allows a single StateMachine engine to be shared across thousands of concurrent payment objects.
  • The Open-Closed Principle: This pattern excels at extension. To add a new transition, you simply write a new function with a decorator rather than modifying an existing class hierarchy.
  • Complex Internal State: If a specific state requires massive amounts of internal data that only matters during that phase, the traditional class-based approach might still be superior to avoid cluttering a single context object.
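The first tip, keeping the current state in the subject, can be sketched as follows, assuming a Payment dataclass that owns its state and a single engine shared by every instance. The fire method and the PAY_SM module-level instance are hypothetical; this version registers transitions with add_transition instead of decorators to keep the example short.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Dict, Generic, Tuple, TypeVar

S = TypeVar("S")
E = TypeVar("E")
C = TypeVar("C")


@dataclass
class StateMachine(Generic[S, E, C]):
    # Stateless engine: it stores only the transition table, never a current state.
    transitions: Dict[Tuple[S, E], Tuple[S, Callable[[C], None]]] = field(default_factory=dict)

    def add_transition(self, from_state: S, event: E, to_state: S, action: Callable[[C], None]) -> None:
        self.transitions[(from_state, event)] = (to_state, action)

    def handle(self, context: C, current_state: S, event: E) -> S:
        if (current_state, event) not in self.transitions:
            raise ValueError(f"Invalid transition from {current_state} with {event}")
        next_state, action = self.transitions[(current_state, event)]
        action(context)
        return next_state


class PayState(Enum):
    NEW = auto()
    AUTHORIZED = auto()
    FAILED = auto()


class PayEvent(Enum):
    AUTHORIZE = auto()
    FAIL = auto()


@dataclass
class Payment:
    # The subject owns its current state; the engine is shared.
    id: str
    state: PayState = PayState.NEW
    audit: list[str] = field(default_factory=list)

    def fire(self, event: PayEvent) -> None:
        # Hypothetical helper: delegate to the shared engine, then store the result.
        self.state = PAY_SM.handle(self, self.state, event)


PAY_SM = StateMachine[PayState, PayEvent, Payment]()
PAY_SM.add_transition(PayState.NEW, PayEvent.AUTHORIZE, PayState.AUTHORIZED,
                      lambda p: p.audit.append(f"Authorized payment {p.id}"))
for source in (PayState.NEW, PayState.AUTHORIZED):
    PAY_SM.add_transition(source, PayEvent.FAIL, PayState.FAILED,
                          lambda p: p.audit.append("Transaction failed"))

payments = [Payment(id=f"p-{i}") for i in range(3)]
payments[0].fire(PayEvent.AUTHORIZE)
payments[1].fire(PayEvent.FAIL)
# payments[2] never receives an event and stays NEW.
```

Each payment ends up in its own state while PAY_SM holds nothing but the table, which is what makes sharing one engine across many objects safe.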