Error propagation with Sympy

Published

September 17, 2020

Sympy is a Python module for symbolic computation (like Mathematica and Matlab) with an elegant Python design.

It handles mechanical tasks such as computing the first derivatives needed for error propagation. We use this here to write a generic error propagation function in little more than 20 lines.
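As a minimal sketch of the idea, assume two independent inputs a and b with variances C_a and C_b. The product a * b is only an illustrative choice. First-order propagation then sums the squared first derivatives, each weighted by the variance of the corresponding input:

```python
import sympy

# Symbols for two independent inputs and their variances (illustrative names).
a, b, C_a, C_b = sympy.symbols("a b C_a C_b")

# Example expression; any differentiable expression would do.
f = a * b

# First-order propagation for independent inputs:
# C_f = (df/da)^2 * C_a + (df/db)^2 * C_b
C_f = f.diff(a) ** 2 * C_a + f.diff(b) ** 2 * C_b
print(C_f)
```

For the product, this gives b squared times C_a plus a squared times C_b, the familiar rule for multiplying two uncorrelated quantities.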

import sympy
print(f"Sympy version {sympy.__version__}")
Sympy version 1.8

Independent variables

def value_and_covariance_gen(expr, variables):
    expr = sympy.parse_expr(expr)

    symbols = sympy.symbols(variables)
    cov_symbols = sympy.symbols(tuple("C_" + k for k in variables))
    expr2 = sum(expr.diff(s) ** 2 * c for s, c in zip(symbols, cov_symbols))
    expr2 = expr2.simplify() # recommended for speed and accuracy

    fval = sympy.lambdify(symbols, expr)
    fcov = sympy.lambdify(symbols + cov_symbols, expr2)

    def fn(**kwargs):
        # look up inputs by name so the keyword order at call time does not matter
        x = tuple(kwargs[k][0] for k in variables)
        c = tuple(kwargs[k][1] for k in variables)
        return fval(*x), fcov(*x, *c)

    return fn


def value_and_covariance(expr, **kwargs):
    return value_and_covariance_gen(expr, tuple(kwargs))(**kwargs)

That’s all, folks!

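value_and_covariance_gen generates a Python function that computes the value of the expression and the propagated covariance. Each input is passed to the generated function as a keyword argument holding a (value, variance) pair.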

value_and_covariance is just a shortcut that generates the function and calls it immediately. If the same expression is evaluated several times with different values, it is more efficient to generate the function once with value_and_covariance_gen and then call it repeatedly.

Generating the propagation function takes about 6 ms on my computer. The generated Python code is as fast to evaluate as any other function that calls numpy functions.
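The following sketch shows the generate-once, call-many pattern. The generator is repeated in compact form so that the snippet runs on its own. Passing NumPy arrays as values is an assumption: it relies on the lambdified expression containing only operations that broadcast over arrays, which holds for this simple ratio.

```python
import numpy as np
import sympy


def value_and_covariance_gen(expr, variables):
    # Compact copy of the generator from this post, so the snippet is self-contained.
    expr = sympy.parse_expr(expr)
    symbols = sympy.symbols(variables)
    cov_symbols = sympy.symbols(tuple("C_" + k for k in variables))
    expr2 = sum(expr.diff(s) ** 2 * c for s, c in zip(symbols, cov_symbols))
    fval = sympy.lambdify(symbols, expr)
    fcov = sympy.lambdify(symbols + cov_symbols, expr2)

    def fn(**kwargs):
        x = tuple(kwargs[k][0] for k in variables)
        c = tuple(kwargs[k][1] for k in variables)
        return fval(*x), fcov(*x, *c)

    return fn


# Generate the propagation function once ...
propagate = value_and_covariance_gen("s / (s + b)", ("s", "b"))

# ... then reuse it; here for three signal values at once (assumed array input).
s = np.array([5.0, 10.0, 20.0])
b = np.array([10.0, 10.0, 10.0])
value, cov = propagate(s=(s, 0.5), b=(b, 0.1))
print(value)
print(cov)
```

The symbolic differentiation and code generation happen only in the first step, so each later call costs no more than evaluating two ordinary numerical functions.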

Limitations of this implementation

- Only independent (uncorrelated) inputs are supported here
- The expression must consist of basic math (it must be parsable by Sympy)

Example usage

value_and_covariance("a + b", a=(1, 0.1), b=(2, 0.2))
(3, 0.30000000000000004)
value_and_covariance("s / (s + b)", s=(5, 0.5), b=(10, 0.1))
(0.3333333333333333, 0.001037037037037037)
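As a cross-check, the second example can be reproduced without Sympy by writing out the two first derivatives by hand. This is only a sketch for verification. The variable names are illustrative, and the second element of each input pair is treated as a variance:

```python
# Manual cross-check of f = s / (s + b) with hand-derived derivatives.
s, C_s = 5, 0.5   # value and variance of s
b, C_b = 10, 0.1  # value and variance of b

df_ds = b / (s + b) ** 2    # d/ds [s / (s + b)]
df_db = -s / (s + b) ** 2   # d/db [s / (s + b)]

value = s / (s + b)
cov = df_ds ** 2 * C_s + df_db ** 2 * C_b
print(value, cov)
```

Both numbers agree with the output above up to floating-point rounding.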