Refactoring Data Science: Clean Architecture with Protocols and Function Composition

Overview

Data science projects often start as experimental scripts where speed of iteration outweighs software design. However, as models grow in complexity, these scripts become difficult to maintain and nearly impossible to reuse. This refactoring focuses on applying professional software engineering principles to a Python-based digit recognition project using the MNIST dataset. By implementing structural abstractions and functional patterns, we can transform a monolithic script into a modular, testable application that separates the concerns of data loading, experiment tracking, and model execution.

Prerequisites

Refactoring A Data Science Project Part 1 - Abstraction and Composition

To follow this walkthrough, you should have a solid grasp of Python syntax, particularly classes and decorators. Familiarity with PyTorch tensors and basic machine learning concepts (training loops, epochs, and metrics) is helpful. You should also understand the basics of type hinting, as we will use it to enforce data consistency throughout the refactor.

Key Libraries & Tools

  • PyTorch: A machine learning framework used here for building the neural network and handling data loaders.
  • TensorBoard: A visualization tool used to track experiment metrics like accuracy and loss.
  • NumPy & Pandas: Essential tools for data manipulation and numerical computation.
  • typing.Protocol: A Python feature used for structural subtyping to create flexible interfaces.
  • functools: A standard library module used for higher-order functions, specifically reduce for function composition.

Code Walkthrough: Structural Abstraction

One common mistake in data science code is tight coupling between the experiment logic and the tracking tool. Initially, the project used an abstract base class (ABC) for tracking, but it still contained implementation details that forced the main script to depend on TensorBoard specifics.

Moving from ABCs to Protocols

We replaced the abstract base class with a Python Protocol. Protocols allow for "duck typing" with static type checking, meaning any class that implements the required methods automatically satisfies the interface without needing explicit inheritance.

from typing import Protocol
from enum import Enum, auto

class Stage(Enum):
    TRAIN = auto()
    TEST = auto()
    VAL = auto()

class ExperimentTracker(Protocol):
    def set_stage(self, stage: Stage) -> None: ...
    def add_batch_metric(self, name: str, value: float) -> None: ...
    def flush(self) -> None: ...

This change decouples our training loop from the storage backend. Whether we log to TensorBoard, a CSV file, or a cloud service, the training code remains untouched.
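To make the decoupling concrete, here is a minimal sketch of a tracker that satisfies the Protocol. The class name `ConsoleTracker` and the `run_epoch` helper are hypothetical, invented for illustration; the point is that the class never inherits from `ExperimentTracker` — matching the method signatures is enough.

```python
from enum import Enum, auto
from typing import Protocol


class Stage(Enum):
    TRAIN = auto()
    TEST = auto()
    VAL = auto()


class ExperimentTracker(Protocol):
    def set_stage(self, stage: Stage) -> None: ...
    def add_batch_metric(self, name: str, value: float) -> None: ...
    def flush(self) -> None: ...


class ConsoleTracker:
    """Hypothetical tracker that buffers metrics and prints them on flush."""

    def __init__(self) -> None:
        self.stage = Stage.TRAIN
        self.history: list[tuple[str, str, float]] = []

    def set_stage(self, stage: Stage) -> None:
        self.stage = stage

    def add_batch_metric(self, name: str, value: float) -> None:
        # Record (stage, metric name, value) for later inspection.
        self.history.append((self.stage.name, name, value))

    def flush(self) -> None:
        for stage, name, value in self.history:
            print(f"{stage}/{name}: {value:.4f}")


def run_epoch(tracker: ExperimentTracker) -> None:
    # The training code only sees the Protocol type, never the concrete class.
    tracker.set_stage(Stage.TRAIN)
    tracker.add_batch_metric("accuracy", 0.93)
    tracker.flush()
```

A static type checker accepts `run_epoch(ConsoleTracker())` because `ConsoleTracker` structurally matches `ExperimentTracker`; swapping in a TensorBoard-backed tracker requires no change to `run_epoch`.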

The Problem with Variable Shadowing

A frequent pattern in PyTorch models is reassigning the same variable (often x) throughout the forward pass. While this saves memory, it makes debugging difficult because x represents a different state of the data at every line.

Implementing Sequential Networks

To solve this, we use torch.nn.Sequential. This composes layers into a single pipeline, eliminating intermediate variables and making the data flow declarative.

# Before refactor: hard to track state
def forward(self, x):
    x = self.flatten(x)
    x = self.linear_relu_stack(x)
    return x

# After refactor: clean composition
self.network = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)

def forward(self, x):
    return self.network(x)

Syntax Notes: Function Composition

If you aren't using a framework like PyTorch or Scikit-learn, you can still achieve clean pipelines using Python's functools.reduce. This is a powerful functional programming technique where you pass a value through a list of functions. We defined a compose function that takes multiple functions and returns a single callable:

from functools import reduce
from typing import Callable

def compose(*functions: Callable[[float], float]) -> Callable[[float], float]:
    return reduce(lambda f, g: lambda x: g(f(x)), functions)

This pattern turns nested calls like h(g(f(x))) into the readable, left-to-right sequence compose(f, g, h)(x), significantly reducing nested parentheses and improving maintainability.
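Here is a short usage sketch of compose with some hypothetical preprocessing steps (the names scale, center, and double are invented for illustration):

```python
from functools import reduce
from typing import Callable


def compose(*functions: Callable[[float], float]) -> Callable[[float], float]:
    # reduce chains the functions left to right: the output of f feeds into g.
    return reduce(lambda f, g: lambda x: g(f(x)), functions)


# Hypothetical preprocessing steps for a pixel value:
def scale(x: float) -> float:
    return x / 255.0      # map [0, 255] to [0, 1]

def center(x: float) -> float:
    return x - 0.5        # shift to [-0.5, 0.5]

def double(x: float) -> float:
    return x * 2.0        # stretch to [-1, 1]

pipeline = compose(scale, center, double)
# pipeline(255.0) == ((255.0 / 255.0) - 0.5) * 2.0 == 1.0
```

Note the application order: compose(scale, center, double) applies scale first, so the call reads in the same order the data flows, much like nn.Sequential.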

Tips & Gotchas

  • Be Explicit with Types: Mixing numbers.Real and float annotations in Python can lead to subtle bugs or annoying linter warnings. Stick to float for consistency across metrics and model weights.
  • Use Enums for States: Avoid using strings like "train" or "test" for experiment stages. Enums prevent typos and provide better IDE completion.
  • YAGNI (You Ain't Gonna Need It): Don't implement convenience methods in your abstract classes if they aren't currently used. Keep your interfaces lean and focused on what the application actually needs.
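The Enum tip above is easy to demonstrate. With string keys, a typo creates a new, silently wrong key; with the Stage enum from earlier, the same typo fails immediately with an AttributeError:

```python
from enum import Enum, auto


class Stage(Enum):
    TRAIN = auto()
    TEST = auto()
    VAL = auto()


# Strings invite typos that fail silently:
metrics: dict[str, float] = {}
metrics["trian"] = 0.9              # typo goes unnoticed until much later

# Enums fail loudly and autocomplete in the IDE:
enum_metrics = {Stage.TRAIN: 0.9}
assert Stage.TRAIN in enum_metrics
# Stage.TRIAN would raise AttributeError the moment it runs
```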
