The callback pattern

Author

Akshay Shankar

Published

January 27, 2025

One of my favourite design patterns is the Callback (also known as the Command pattern in object-oriented contexts). Loosely speaking, a callback is simply a function f whose reference has been passed to another function g, which then invokes f at a later time, typically once its own work is complete. It is trivially usable in languages that implement functions as first-class citizens (where they may be passed around as arguments with no hassle). While the idea is quite simple, it can lend itself to some very powerful and intuitive user interfaces! I’ve had natural use-cases arise both during my day-to-day research as well as during game development.
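
As a minimal sketch of this definition (the names `run_then` and `report` are hypothetical and only serve the illustration), the function playing the role of g does some work and then hands the result to whatever f it was given:

```julia
# `run_then` (hypothetical name) plays the role of g: it does some work
# and then invokes the callback it was handed with the result.
function run_then(callback, xs)
    total = sum(xs)          # the "work" done by g
    return callback(total)   # invoke f once the work is complete
end

# A named function and an anonymous function can both serve as f.
report(total) = "total = $total"

msg = run_then(report, 1:10)
doubled = run_then(t -> 2t, 1:10)

# `do`-block syntax passes the block as the first argument (the callback)
squared = run_then(1:10) do t
    t^2
end
```

Taking the callback as the first argument is a deliberate choice here, since it is what makes the `do`-block form available.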

Let us consider a simple and perhaps a bit over-engineered example using Julia to demonstrate this idea. Say we want to write a generic interface for a differential equation solver. (Note that much of what I present here is simply a poor man’s version of the equivalent implementation in the behemoth that is DifferentialEquations.jl.)

A generic integrator interface

RK4 seems like a good place to start. We begin by defining a struct to hold the state of the integrator, (u, tspan), where u can be anything that implements similar and supports basic algebraic operations and broadcasting, and tspan is an AbstractRange specifying the time interval of the problem. The remaining fields are preallocated work buffers that let us avoid allocations in the hot loop.

abstract type AbstractIntegrator end

struct RK4Integrator{T,S} <: AbstractIntegrator
    # current solver state
    u::T
    tspan::S

    # intermediate variables
    k1::T
    k2::T
    k3::T
    k4::T
    tmp::T

    RK4Integrator(u, tspan) = new{typeof(u),typeof(tspan)}(
        u, tspan, similar(u), similar(u), similar(u), similar(u), similar(u)
    )
end

Every integrator is expected to implement a step! method that takes a function f! computing the derivative in place, together with the current time t, and advances the state by one time-step.

function step!(integrator::RK4Integrator, f!, t)
    (; u, tspan, k1, k2, k3, k4, tmp) = integrator
    dt = step(tspan)

    f!(k1, u, t)

    @. tmp = u + dt / 2 * k1
    f!(k2, tmp, t + dt / 2)

    @. tmp = u + dt / 2 * k2
    f!(k3, tmp, t + dt / 2)

    @. tmp = u + dt * k3
    f!(k4, tmp, t + dt)

    @. u += dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return u
end
step! (generic function with 1 method)
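
Since an integrator only has to provide a constructor and a step! method, other schemes can slot into the same interface. As a rough sketch, here is what a forward-Euler integrator could look like; `EulerIntegrator` and `decay!` are hypothetical names that do not appear elsewhere in this post, and the layout simply mirrors RK4Integrator.

```julia
abstract type AbstractIntegrator end

# Hypothetical forward-Euler integrator with the same layout as RK4Integrator
struct EulerIntegrator{T,S} <: AbstractIntegrator
    u::T
    tspan::S
    k1::T  # preallocated buffer for the derivative

    EulerIntegrator(u, tspan) = new{typeof(u),typeof(tspan)}(u, tspan, similar(u))
end

function step!(integrator::EulerIntegrator, f!, t)
    (; u, tspan, k1) = integrator
    dt = step(tspan)

    f!(k1, u, t)       # k1 = f(u, t)
    @. u += dt * k1    # u(t + dt) ≈ u(t) + dt * f(u, t)
    return u
end

# du/dt = -u, one step of size 0.1 starting from u = 1
decay!(du, u, t) = (du .= .-u)
integrator = EulerIntegrator([1.0], range(0.0, 1.0, length=11))
step!(integrator, decay!, 0.0)
```

After this single step the state has moved from 1.0 to approximately 0.9, as expected for one Euler update with dt = 0.1.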

Then, we require only one generic function that actually loops through the time-steps. Below we just implement a basic version, but there is nothing stopping us from being more sophisticated with adaptive algorithms as well.

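function solve!(f!, u0, tspan; solver, (callback!)=(iter, integrator) -> nothing)
    # the integrator wraps u0 itself (no copy), so u0 is updated in place
    integrator = solver(u0, tspan)

    # one step per grid point; the callback runs after each completed step
    for iter in eachindex(tspan)
        step!(integrator, f!, tspan[iter])
        callback!(iter, integrator)
    end

    return integrator.u
end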
solve! (generic function with 1 method)

At this point, we introduce the notion of a callback as a (possibly mutating) function that takes the input (iter, integrator)::(Integer, AbstractIntegrator) and is invoked at the end of every iteration. We will soon see that the callback can be utilized by the user to run custom logic within the integrator loop without ever having to touch the actual internals. In the meantime, we can now solve any ordinary differential equation with the RK4 method!

function dfdt!(du, u, t)
    du[1] = -5. * u[1]
end

solve!(dfdt!, [5.], range(0., 1., length=100), solver = RK4Integrator)

Note that this function only provides access to the value of u after the final time-step, which is a bit weird since typically we would want the evolution of the state as a time-series. While this was an oversight on my part, it also provides a good opportunity to utilize callbacks to store custom data during the solver steps. In order to do this, we may simply create a callback function like so. (Note that the keyword argument callback! has to be wrapped in parentheses, or at least separated from the = by a space, since callback!=f would otherwise be parsed as the comparison callback != f.)

# we construct the callback inside another function to create a closure over `data`
# instead of leaving it as a global variable (which would degrade performance!)
function create_saving_callback()
    data = []
    saving_callback = (iter, integrator) -> push!(data, integrator.u[1])
    return data, saving_callback
end

using Plots # assumption: `plot` comes from Plots.jl (the import is not shown in the original)

data, saving_callback = create_saving_callback()
solve!(dfdt!, [5.], range(0., 1., length=100), solver = RK4Integrator, (callback!)=saving_callback)
plot(data)
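
The remark about the != ambiguity can be checked directly by asking the parser what it sees. The following is only a small sketch using Meta.parse on strings; `solve` and `f` are placeholder names that are never evaluated.

```julia
# Without parentheses, the parser reads `callback!=f` as a comparison,
# so the call receives a single positional argument `callback != f`
ambiguous = Meta.parse("solve(callback!=f)")

# With parentheses around the name, it is a keyword argument named `callback!`
wrapped = Meta.parse("solve((callback!)=f)")

# Whitespace before `=` resolves the ambiguity in the same way
spaced = Meta.parse("solve(callback! = f)")

# Inspect the second element of each call expression (the first is `solve`)
is_comparison = ambiguous.args[2] == :(callback != f)
wrapped_is_keyword = wrapped.args[2].head == :kw
spaced_is_keyword = spaced.args[2].head == :kw
```

All three flags come out true, which is why the parenthesized form is used throughout this post.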

This works perfectly fine and is exactly what one would (should?) do for personal code! However, this is not a very modular approach. If the goal is to build a package with an intuitive and flexible user-facing API, we can build upon this a lot more to account for common use-cases. Let us see a possible way to achieve this.

A small aside

Generally, custom logic can be stateful (i.e., have persistent local variables) and one would need to create a closure over the function that actually performs the mutating action on the state of the integrator (just like we did above). However, Julia offers another alternative; namely, we can define a struct that encapsulates the data, which can then be invoked as a function with access to this data. Since both structs and functions may be invoked by means of the function call syntax, I will generically refer to them as callables henceforth.
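
To make the two options concrete, here is a small sketch of the same stateful counter written once as a closure and once as a callable struct. The names `make_counter_closure` and `Counter` are hypothetical and only used for this illustration.

```julia
# Option 1: closure capturing a local variable
function make_counter_closure()
    count = Ref(0)                        # state captured by the closure
    return () -> (count[] += 1; count[])
end

# Option 2: callable struct carrying the same state in a field
mutable struct Counter
    count::Int
end
Counter() = Counter(0)

# this method definition makes instances of Counter invocable like functions
(c::Counter)() = (c.count += 1; c.count)

tick_closure = make_counter_closure()
tick_struct = Counter()

tick_closure(); tick_closure()
tick_struct(); tick_struct(); tick_struct()
```

One practical difference is that the state of the struct stays inspectable from the outside (here through `tick_struct.count`), whereas the closure hides it unless it is returned separately, as we did with `data` earlier.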

The callback struct

In order to define a generic interface, we first need to think about what the general use-case of callbacks would be. In the context of an integrator, we expect that callbacks will always have the specific form of evaluating whether a certain condition is met at the time of invocation and, if so, performing a certain effect that may mutate the state of the integrator. For example, we may want to compute and save a certain quantity every iteration, or normalize the state whenever it deviates beyond a certain threshold, or apply a periodic perturbation every 10 iterations, etc. So, we define a Callback struct as a collection of two callables: (1) condition with signature (iter, integrator) -> Bool and (2) effect with signature (iter, integrator) -> integrator, although it can be mutating as well.

begin
    struct Callback{C,E}
        "condition for performing callback: (iter, integrator) -> bool"
        condition::C
        "callback function acting on solver state: (iter, integrator) -> integrator"
        effect::E
    end
    
    function (p::Callback)(iter, integrator)
        if p.condition(iter, integrator)
            p.effect(iter, integrator)
        end
    
        return integrator
    end
end
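
Neither field has to be a struct; plain anonymous functions work just as well. As a quick sketch, the snippet below repeats the Callback definition so that it runs on its own and drives it by hand with a stand-in integrator (`mock` and `halve_on_even` are hypothetical names; a named tuple with a `u` field is all this example needs).

```julia
struct Callback{C,E}
    condition::C
    effect::E
end

function (p::Callback)(iter, integrator)
    if p.condition(iter, integrator)
        p.effect(iter, integrator)
    end
    return integrator
end

# Stand-in for an integrator: anything with a mutable `u` field works here
mock = (u = [4.0],)

# halve u[1] on even iterations only
halve_on_even = Callback(
    (iter, integrator) -> iseven(iter),
    (iter, integrator) -> (integrator.u[1] /= 2),
)

for iter in 1:4
    halve_on_even(iter, mock)
end
```

The effect fires on iterations 2 and 4 only, so `mock.u[1]` ends at 1.0.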

We can further define a CallbackList that sequentially invokes its elements if we have more than one callback.

begin
    struct CallbackList
        callbacks::Vector{Callback}
    end

    function (p::CallbackList)(iter, integrator)
        for callback in p.callbacks
            callback(iter, integrator)
        end

        return integrator
    end

    Base.getindex(p::CallbackList, idx) = getindex(p.callbacks, idx)
    Base.length(p::CallbackList) = length(p.callbacks)
end
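
Because the list is traversed sequentially, the order of the callbacks matters whenever their effects do not commute. A small sketch of this, again with both definitions repeated so that it runs on its own, and with hypothetical callbacks `add_one` and `double` acting on stand-in integrators:

```julia
struct Callback{C,E}
    condition::C
    effect::E
end

function (p::Callback)(iter, integrator)
    if p.condition(iter, integrator)
        p.effect(iter, integrator)
    end
    return integrator
end

struct CallbackList
    callbacks::Vector{Callback}
end

function (p::CallbackList)(iter, integrator)
    for callback in p.callbacks
        callback(iter, integrator)
    end
    return integrator
end

# two effects that do not commute, both firing unconditionally
always = (iter, integrator) -> true
add_one = Callback(always, (iter, integrator) -> (integrator.u[1] += 1))
double = Callback(always, (iter, integrator) -> (integrator.u[1] *= 2))

a = (u = [1.0],)
b = (u = [1.0],)
CallbackList([add_one, double])(1, a)   # (1 + 1) * 2
CallbackList([double, add_one])(1, b)   # 1 * 2 + 1
```

This is worth keeping in mind when, for instance, a recording callback is combined with one that modifies the state: whether the recorded value is taken before or after the modification depends purely on the position in the list.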

Some common conditions and effects

Now that the basic structure is in place, let us implement some conditions that may be required quite often. Again, these don’t have to be structs, but since the specific use-cases here require statefulness, they are an appropriate choice. Note that we have not placed explicit safeguards to statically check whether the function call has the right signature, so the program would fail at run-time if the signature does not match what is expected.

begin
    # trigger every n iterations
    mutable struct OnIterElapsed
        "number of iterations between trigger"
        save_freq::Int

        loop::Bool # if true, continuously fires, otherwise it is a one-shot condition
        flag::Bool # true if condition has been fired once

        OnIterElapsed(save_freq, loop=true) = new(save_freq, loop, false)
    end

    function (p::OnIterElapsed)(iter, integrator)
        res = iszero(iter % p.save_freq)
        return p.loop ? res : (!p.flag ? (p.flag = res; p.flag) : false)
    end

    # maybe for saving data to file as a backup during long program runs
    mutable struct OnRealTimeElapsed
        "starting time in seconds"
        start_tick::Float64
        "number of seconds between trigger"
        save_freq::Float64

        # :s - second, :m - minute, :h - hour
        function OnRealTimeElapsed(freq, unit=:m)
            if unit == :m
                freq *= 60
            elseif unit == :h
                freq *= 3600
            elseif unit != :s
                throw(ArgumentError("invalid unit `:$unit`, expected `:s`, `:m` or `:h`"))
            end

            new(time(), freq)
        end
    end

    # only gets called AFTER an iteration is complete!
    # (same (iter, integrator) signature as every other condition; resets the timer when it fires)
    (p::OnRealTimeElapsed)(iter, integrator) = ((time() - p.start_tick) > p.save_freq) ? (p.start_tick = time(); true) : false
end
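
To make the difference between the looping and one-shot modes of OnIterElapsed concrete, here is a small sketch that repeats its definition from above and calls it by hand. The variable names are hypothetical, and `nothing` is passed in place of an integrator since this condition never looks at it.

```julia
mutable struct OnIterElapsed
    save_freq::Int
    loop::Bool
    flag::Bool

    OnIterElapsed(save_freq, loop=true) = new(save_freq, loop, false)
end

function (p::OnIterElapsed)(iter, integrator)
    res = iszero(iter % p.save_freq)
    return p.loop ? res : (!p.flag ? (p.flag = res; p.flag) : false)
end

repeating = OnIterElapsed(3)
one_shot = OnIterElapsed(3, false)

# collect the iterations at which each condition fires
fired_repeating = [iter for iter in 1:9 if repeating(iter, nothing)]
fired_one_shot = [iter for iter in 1:9 if one_shot(iter, nothing)]
```

The looping condition fires at iterations 3, 6 and 9, while the one-shot version fires at iteration 3 and then stays silent because its flag has been set.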

We now define a generic effect that simply computes some specified observables using the integrator state. We do so by defining a struct RecordObservable which expects an input recipe, a named tuple containing a collection of (observable_name = function_to_compute_observable) entries. For example, (norm = (iter, integrator) -> norm(integrator.u), iter = (iter, integrator) -> iter). It then stores the result in data whenever the condition of its parent Callback returns true. While this is a trivial example, it may be quite useful if there is some non-trivial internal integrator state that must be tracked.

begin
    struct RecordObservable{D,O}
        "collection of string-array pairs containing observable data"
        data::D
        "functions to compute observable data"
        observables::O

        function RecordObservable(recipe)
            names = keys(recipe)
            observables = values(recipe)
            data = NamedTuple{names}(([] for _ in eachindex(names)))
            return new{typeof(data),typeof(observables)}(data, observables)
        end
    end

    # to access the data without an extra .data; bit iffy, but functional
    function Base.getproperty(p::RecordObservable, key::Symbol)
        if key in fieldnames(typeof(p))
            return getfield(p, key)
        else
            return getproperty(p.data, key)
        end
    end

    Base.length(p::RecordObservable) = length(p.data)

    function (p::RecordObservable)(iter, integrator)
        for i in 1:length(p)
            push!(p.data[i], p.observables[i](iter, integrator))
        end

        return integrator
    end
end

Having done this, we can now once again visualize the solution as a time-series at every iteration.

let
    record_state = Callback(
        OnIterElapsed(1),
        RecordObservable((u=(iter, integrator) -> integrator.u[1],))
    )
    solve!(dfdt!, [5.0], range(0.0, 1.0, length=100), solver=RK4Integrator, (callback!)=record_state)
    plot(record_state.effect.u)
end
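
The same effect can track several observables at once. As a sketch, the snippet below repeats the RecordObservable definition so that it runs on its own and uses the recipe from the example in the text; the `mock` named tuple is a hypothetical stand-in for an integrator, and `norm` is assumed to come from the LinearAlgebra standard library.

```julia
using LinearAlgebra: norm

struct RecordObservable{D,O}
    data::D
    observables::O

    function RecordObservable(recipe)
        names = keys(recipe)
        observables = values(recipe)
        data = NamedTuple{names}(([] for _ in eachindex(names)))
        return new{typeof(data),typeof(observables)}(data, observables)
    end
end

function Base.getproperty(p::RecordObservable, key::Symbol)
    if key in fieldnames(typeof(p))
        return getfield(p, key)
    else
        return getproperty(p.data, key)
    end
end

Base.length(p::RecordObservable) = length(p.data)

function (p::RecordObservable)(iter, integrator)
    for i in 1:length(p)
        push!(p.data[i], p.observables[i](iter, integrator))
    end
    return integrator
end

recorder = RecordObservable((
    norm = (iter, integrator) -> norm(integrator.u),
    iter = (iter, integrator) -> iter,
))

# record once, change the state in place, then record again
mock = (u = [3.0, 4.0],)
recorder(1, mock)
mock.u .*= 2
recorder(2, mock)
```

Afterwards `recorder.iter` holds the iteration numbers 1 and 2, and `recorder.norm` holds the corresponding norms (approximately 5 and 10), both reachable without the extra `.data` thanks to the getproperty overload.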

Perfect! While it may seem that we introduced a ton of machinery for no reason, we have effectively removed the boilerplate for common tasks such as the condition-effect structure and data saving. The work we put into this allows us to write clean and intuitive code when we want to do more things within a callback. Perhaps a realistic use-case is to specify dynamical conditions that kick in at some intermediate time, for example, a random kick every n iterations. We could also constrain the solution from going below a certain threshold.

let
    record_state = Callback(
        OnIterElapsed(1),
        RecordObservable((u=(iter, integrator) -> integrator.u[1],))
    )

    hardwall_callback = Callback(
        (iter, integrator) -> integrator.u[1] < 1.,
        (iter, integrator) -> integrator.u[1] = 1.,
    )

    kick_callback = Callback(
        OnIterElapsed(30),
        (iter, integrator) -> integrator.u[1] += rand(),
    )

    solve!(dfdt!, [5.0], range(0.0, 1.0, length=100), solver=RK4Integrator, (callback!)=CallbackList([record_state, hardwall_callback, kick_callback]))
    plot(record_state.effect.u, ylims = [0., 5.1])
end

I think this is quite a nice example demonstrating how the callback pattern allows modularity and extensibility for the user with no need to poke into the internals of the core solver loop. However, it should be noted that design patterns such as this one tend to quickly ramp up in complexity, and the overhead introduced both in compilation time and developer maintenance time can often outweigh their usefulness. So it’s important to keep your specific use-case in mind and try to use such constructions only when strictly required.

Before we conclude, it is important to note that the example presented above is a fairly specific utilization of callbacks in the context of scientific software where one may want to inject custom logic inside a core program loop. On the other hand, the concept is also widely prevalent in video game programming and web development, typically presenting itself in the context of asynchronous programming. While the key idea still remains that the invocation of a function is deferred until some other function is completed, the resulting interfaces may take a different form than the one we see here. Perhaps that is a topic for another time.