# Benchmarking

```{autolink-concat}
```

:::{tip}
This notebook benchmarks JAX on a **single CPU core**. Compare with Julia results as reported in [ComPWA/polarimetry#27](https://github.com/ComPWA/polarimetry/issues/27). See also the [Extended benchmark #68](https://github.com/ComPWA/polarimetry/discussions/68) discussion.
:::

:::{note}
This notebook uses only one run and one loop for [`%timeit`](https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-timeit), because JAX [seems to cache its return values](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
:::

In [None]:
from __future__ import annotations

import logging
from collections import defaultdict

import numpy as np
import pandas as pd
import sympy as sp
from IPython.display import Markdown
from psutil import cpu_count

from polarimetry import formulate_polarimetry
from polarimetry.data import (
 create_data_transformer,
 generate_meshgrid_sample,
 generate_phasespace_sample,
)
from polarimetry.io import (
 mute_jax_warnings,
 perform_cached_doit,
 perform_cached_lambdify,
)
from polarimetry.lhcb import (
 load_model_builder,
 load_model_parameters,
 load_three_body_decay,
)
from polarimetry.lhcb.particle import load_particles

# Restrict the root logger to ERROR and above so that log output from the
# libraries used below does not clutter the benchmark cells.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)
mute_jax_warnings()  # presumably suppresses JAX warnings — see polarimetry.io

# Build the amplitude model from the YAML definition files. The paths are
# relative to the location of this notebook.
model_choice = 0  # selects which model in model_file is loaded — TODO confirm what 0 refers to
model_file = "../../data/model-definitions.yaml"
particles = load_particles("../../data/particle-definitions.yaml")
amplitude_builder = load_model_builder(model_file, particles, model_choice)
imported_parameter_values = load_model_parameters(
 model_file, amplitude_builder.decay, model_choice, particles
)
reference_subsystem = 1
model = amplitude_builder.formulate(reference_subsystem)
# Overwrite the defaults of the formulated model with the values loaded from model_file.
model.parameter_defaults.update(imported_parameter_values)

# Benchmark results, filled in by the cells below and rendered in the Summary:
# timing_xxx["intensity" | "polarimetry"][<row title>] = result of a `%%timeit -o` cell.
timing_parametrized = defaultdict(dict)
timing_substituted = defaultdict(dict)

# Report the core counts of the machine that executes the notebook (psutil).
print("Physical cores:", cpu_count(logical=False))
print("Total cores:", cpu_count(logical=True))

In [None]:
%%time
# Formulate the polarimetry expressions for the chosen reference subsystem. The
# rest of the notebook indexes three of them and labels them x, y, z.
polarimetry_exprs = formulate_polarimetry(amplitude_builder, reference_subsystem)
# For each expression: evaluate with doit(), substitute the amplitude
# definitions of the model, and hand the result to perform_cached_doit
# (presumably a cached doit() that fully unfolds the expression — confirm in polarimetry.io).
unfolded_polarimetry_exprs = [
 perform_cached_doit(expr.doit().xreplace(model.amplitudes))
 for expr in polarimetry_exprs
]
# Same unfolding for the total intensity expression of the model.
unfolded_intensity_expr = perform_cached_doit(model.full_expression)

## {class}`~tensorwaves.interface.DataTransformer` performance

In [None]:
# Number of phase space events; also used in the row titles of the summary table.
n_events = 100_000
phsp_sample = generate_phasespace_sample(model.decay, n_events, seed=0)
transformer = create_data_transformer(model)
# Time the identical call three times, each with a single loop and a single run,
# to compare the first call with the subsequent ones.
%timeit -n1 -r1 transformer(phsp_sample) # first run, so no cache and JIT-compilation
%timeit -n1 -r1 transformer(phsp_sample) # second run with cache
%timeit -n1 -r1 transformer(phsp_sample) # third run with cache
# From here on, phsp_sample is the transformed sample that the benchmark cells take as input.
phsp_sample = transformer(phsp_sample)
# Single-event input: element 0 of every entry that has at least one dimension;
# entries without dimensions are passed through unchanged.
random_point = {k: v[0] if len(v.shape) > 0 else v for k, v in phsp_sample.items()}

In [None]:
# Grid resolution; the row titles of the summary table are built as f"{res}x{res} grid".
res = 54
grid_sample = generate_meshgrid_sample(model.decay, res)
# Same three-fold timing as for the phase space sample, now on the grid input.
%timeit -n1 -r1 transformer(grid_sample) # first run, without cache, but already compiled
%timeit -n1 -r1 transformer(grid_sample) # second run with cache
%timeit -n1 -r1 transformer(grid_sample) # third run with cache
# From here on, grid_sample is the transformed sample used by the benchmark cells.
grid_sample = transformer(grid_sample)

## Parametrized function

:::{margin}
Compare {ref}`appendix/benchmark:All parameters substituted`.
:::

In [None]:
# Build the Markdown bullet list line by line and join it once at the end.
op_count_lines = ["Total number of mathematical operations:"]
for index, alpha_expr in enumerate(unfolded_polarimetry_exprs):
    axis = "xyz"[index]  # the polarimetry components are labelled x, y, z
    op_count_lines.append(Rf"- $\alpha_{axis}$: {sp.count_ops(alpha_expr):,}")
op_count_lines.append(
    Rf"- $I_\mathrm{{tot}}$: {sp.count_ops(unfolded_intensity_expr):,}"
)
src = "\n".join(op_count_lines)
# The Markdown object is the last expression, so it is rendered as the cell output.
Markdown(src)

In [None]:
%%time
# Lambdify every unfolded expression to a function with the JAX backend. The
# model's parameter defaults are passed as `parameters`; the resulting functions
# expose `.parameters` and `.update_parameters()`, which later cells use to
# change the values without lambdifying again.
# NOTE(review): if these functions return JAX arrays, asynchronous dispatch could
# let the timings below stop before the computation has finished — confirm that
# the wrapper blocks or converts the result to NumPy.
parametrized_polarimetry_funcs = [
 perform_cached_lambdify(
 expr,
 parameters=model.parameter_defaults,
 backend="jax",
 )
 for expr in unfolded_polarimetry_exprs
]
parametrized_intensity_func = perform_cached_lambdify(
 unfolded_intensity_expr,
 parameters=model.parameter_defaults,
 backend="jax",
)

In [None]:
# Seeded generator so that the perturbed parameter values are reproducible.
rng = np.random.default_rng(seed=0)
# Snapshot of the current values, used later to restore the functions.
original_parameters = dict(parametrized_intensity_func.parameters)
# Scale every parameter by an independent random factor drawn from [0.9, 1.1).
modified_parameters = {
    name: rng.uniform(0.9, 1.1) * value
    for name, value in original_parameters.items()
}

### One data point

#### JIT-compilation

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the parametrized intensity function on a single data point;
# per the section heading, this timing includes JIT-compilation.
array = parametrized_intensity_func(random_point)

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the three parametrized polarimetry functions on a single data
# point, timed together as one measurement (includes JIT-compilation per the heading).
array = parametrized_polarimetry_funcs[0](random_point)
array = parametrized_polarimetry_funcs[1](random_point)
array = parametrized_polarimetry_funcs[2](random_point)

In [None]:
# `_` and `__` are IPython's output-history variables. The two preceding cells ran
# `%%timeit -o`, so `__` is the intensity result (second-to-last output) and `_` the
# polarimetry result (last output). This breaks if another cell with output runs in between.
timing_parametrized["intensity"]["random point (compilation)"] = __
timing_parametrized["polarimetry"]["random point (compilation)"] = _

#### Compiled performance

In [None]:
%%timeit -n1 -r1 -q -o
# Repeat of the single-data-point call, now that the function has been called
# once before; measures the compiled performance per the section heading.
array = parametrized_intensity_func(random_point)

In [None]:
%%timeit -n1 -r1 -q -o
# Repeat of the three single-data-point polarimetry calls after their first
# (compiling) call; all three are timed together as one measurement.
array = parametrized_polarimetry_funcs[0](random_point)
array = parametrized_polarimetry_funcs[1](random_point)
array = parametrized_polarimetry_funcs[2](random_point)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_parametrized["intensity"]["random point (cached)"] = __
timing_parametrized["polarimetry"]["random point (cached)"] = _

### 54x54 grid sample

#### Compiled but uncached

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the parametrized intensity function on the grid sample, i.e. on
# a different input than the single data point used so far.
array = parametrized_intensity_func(grid_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the three parametrized polarimetry functions on the grid sample,
# timed together as one measurement.
array = parametrized_polarimetry_funcs[0](grid_sample)
array = parametrized_polarimetry_funcs[1](grid_sample)
array = parametrized_polarimetry_funcs[2](grid_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_parametrized["intensity"][f"{res}x{res} grid"] = __
timing_parametrized["polarimetry"][f"{res}x{res} grid"] = _

#### Second run with cache

In [None]:
%%timeit -n1 -r1 -q -o
# Same grid-sample call repeated; the section heading attributes the difference
# with the previous measurement to caching.
array = parametrized_intensity_func(grid_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# Same three grid-sample polarimetry calls repeated, timed together as one
# measurement (second run per the section heading).
array = parametrized_polarimetry_funcs[0](grid_sample)
array = parametrized_polarimetry_funcs[1](grid_sample)
array = parametrized_polarimetry_funcs[2](grid_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_parametrized["intensity"][f"{res}x{res} grid (cached)"] = __
timing_parametrized["polarimetry"][f"{res}x{res} grid (cached)"] = _

### 100,000 event phase space sample

#### Compiled but uncached

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the parametrized intensity function on the phase space sample
# of n_events events.
array = parametrized_intensity_func(phsp_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the three parametrized polarimetry functions on the phase space
# sample, timed together as one measurement.
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_parametrized["intensity"][f"{n_events:,} phsp"] = __
timing_parametrized["polarimetry"][f"{n_events:,} phsp"] = _

#### Second run with cache

In [None]:
%%timeit -n1 -r1 -q -o
# Same phase-space-sample call repeated (second run per the section heading).
array = parametrized_intensity_func(phsp_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# Same three phase-space-sample polarimetry calls repeated, timed together as
# one measurement (second run per the section heading).
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_parametrized["intensity"][f"{n_events:,} phsp (cached)"] = __
timing_parametrized["polarimetry"][f"{n_events:,} phsp (cached)"] = _

### Recompilation after parameter modification

In [None]:
# Push the perturbed parameter values into the intensity function first and then
# into each polarimetry function. This cell produces no output, so IPython's
# `_`/`__` history is left untouched.
for func in (parametrized_intensity_func, *parametrized_polarimetry_funcs):
    func.update_parameters(modified_parameters)

#### Compiled but uncached

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the intensity function on the phase space sample after
# update_parameters(modified_parameters) was applied.
array = parametrized_intensity_func(phsp_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the three polarimetry functions on the phase space sample after
# the parameter update, timed together as one measurement.
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_parametrized["intensity"][f"modified {n_events:,} phsp"] = __
timing_parametrized["polarimetry"][f"modified {n_events:,} phsp"] = _

#### Second run with cache

In [None]:
%%timeit -n1 -r1 -q -o
# Same call with the modified parameters repeated (second run per the section heading).
array = parametrized_intensity_func(phsp_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# Same three polarimetry calls with the modified parameters repeated, timed
# together as one measurement (second run per the section heading).
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_parametrized["intensity"][f"modified {n_events:,} phsp (cached)"] = __
timing_parametrized["polarimetry"][f"modified {n_events:,} phsp (cached)"] = _

In [None]:
# Undo the perturbation: restore the snapshot of the original parameter values
# in the intensity function first and then in each polarimetry function.
for func in (parametrized_intensity_func, *parametrized_polarimetry_funcs):
    func.update_parameters(original_parameters)

## All parameters substituted

In [None]:
# Replace the keys of model.parameter_defaults by their values in the unfolded
# expressions, so that the lambdified functions below carry no parameters.
subs_intensity_expr = unfolded_intensity_expr.xreplace(model.parameter_defaults)
subs_polarimetry_exprs = [
    alpha_expr.xreplace(model.parameter_defaults)
    for alpha_expr in unfolded_polarimetry_exprs
]

:::{margin}
Compare {ref}`appendix/benchmark:Parametrized function`.
:::

In [None]:
# Build the Markdown bullet list line by line and join it once at the end.
op_count_lines = [
    "Number of mathematical operations after substituting all parameters:"
]
for index, alpha_expr in enumerate(subs_polarimetry_exprs):
    axis = "xyz"[index]  # the polarimetry components are labelled x, y, z
    op_count_lines.append(Rf"- $\alpha_{axis}$: {sp.count_ops(alpha_expr):,}")
op_count_lines.append(
    Rf"- $I_\mathrm{{tot}}$: {sp.count_ops(subs_intensity_expr):,}"
)
src = "\n".join(op_count_lines)
# The Markdown object is the last expression, so it is rendered as the cell output.
Markdown(src)

In [None]:
%%time
# Lambdify the parameter-substituted expressions with the JAX backend. No
# `parameters` argument is passed here, in contrast to the parametrized functions.
polarimetry_funcs = [
 perform_cached_lambdify(expr, backend="jax") for expr in subs_polarimetry_exprs
]
intensity_func = perform_cached_lambdify(subs_intensity_expr, backend="jax")

### One data point

#### JIT-compilation

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the substituted intensity function on a single data point;
# per the section heading, this timing includes JIT-compilation.
array = intensity_func(random_point)

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the three substituted polarimetry functions on a single data
# point, timed together as one measurement (includes JIT-compilation per the heading).
array = polarimetry_funcs[0](random_point)
array = polarimetry_funcs[1](random_point)
array = polarimetry_funcs[2](random_point)

In [None]:
# `_` and `__` are IPython's output-history variables. The two preceding cells ran
# `%%timeit -o`, so `__` is the intensity result (second-to-last output) and `_` the
# polarimetry result (last output). This breaks if another cell with output runs in between.
timing_substituted["intensity"]["random point (compilation)"] = __
timing_substituted["polarimetry"]["random point (compilation)"] = _

#### Compiled performance

In [None]:
%%timeit -n1 -r1 -q -o
# Repeat of the single-data-point call, now that the function has been called
# once before; measures the compiled performance per the section heading.
array = intensity_func(random_point)

In [None]:
%%timeit -n1 -r1 -q -o
# Repeat of the three single-data-point polarimetry calls after their first
# (compiling) call; all three are timed together as one measurement.
array = polarimetry_funcs[0](random_point)
array = polarimetry_funcs[1](random_point)
array = polarimetry_funcs[2](random_point)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_substituted["intensity"]["random point (cached)"] = __
timing_substituted["polarimetry"]["random point (cached)"] = _

### 54x54 grid sample

#### Compiled but uncached

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the substituted intensity function on the grid sample, i.e. on
# a different input than the single data point used so far.
array = intensity_func(grid_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the three substituted polarimetry functions on the grid sample,
# timed together as one measurement.
array = polarimetry_funcs[0](grid_sample)
array = polarimetry_funcs[1](grid_sample)
array = polarimetry_funcs[2](grid_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_substituted["intensity"][f"{res}x{res} grid"] = __
timing_substituted["polarimetry"][f"{res}x{res} grid"] = _

#### Second run with cache

In [None]:
%%timeit -n1 -r1 -q -o
# Same grid-sample call repeated; the section heading attributes the difference
# with the previous measurement to caching.
array = intensity_func(grid_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# Same three grid-sample polarimetry calls repeated, timed together as one
# measurement (second run per the section heading).
array = polarimetry_funcs[0](grid_sample)
array = polarimetry_funcs[1](grid_sample)
array = polarimetry_funcs[2](grid_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_substituted["intensity"][f"{res}x{res} grid (cached)"] = __
timing_substituted["polarimetry"][f"{res}x{res} grid (cached)"] = _

### 100,000 event phase space sample

#### Compiled but uncached

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the substituted intensity function on the phase space sample
# of n_events events.
array = intensity_func(phsp_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# First call of the three substituted polarimetry functions on the phase space
# sample, timed together as one measurement.
array = polarimetry_funcs[0](phsp_sample)
array = polarimetry_funcs[1](phsp_sample)
array = polarimetry_funcs[2](phsp_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_substituted["intensity"][f"{n_events:,} phsp"] = __
timing_substituted["polarimetry"][f"{n_events:,} phsp"] = _

#### Second run with cache

In [None]:
%%timeit -n1 -r1 -q -o
# Same phase-space-sample call repeated (second run per the section heading).
array = intensity_func(phsp_sample)

In [None]:
%%timeit -n1 -r1 -q -o
# Same three phase-space-sample polarimetry calls repeated, timed together as
# one measurement (second run per the section heading).
array = polarimetry_funcs[0](phsp_sample)
array = polarimetry_funcs[1](phsp_sample)
array = polarimetry_funcs[2](phsp_sample)

In [None]:
# `__` is the output of the intensity `%%timeit -o` cell (second-to-last output) and
# `_` that of the polarimetry `%%timeit -o` cell (last output); keep the cell order intact.
timing_substituted["intensity"][f"{n_events:,} phsp (cached)"] = __
timing_substituted["polarimetry"][f"{n_events:,} phsp (cached)"] = _

## Summary

In [None]:
def collect_sorted_row_title() -> list[str]:
    """Return every benchmark row title that has been recorded, without duplicates.

    The titles are read from the module-level ``timing_parametrized`` and
    ``timing_substituted`` dictionaries, visiting ``"intensity"`` before
    ``"polarimetry"`` in each. Despite the function name, no sorting is
    applied: a title appears at the position where it was first encountered.
    """
    seen: dict[str, None] = {}
    for timings in (timing_parametrized, timing_substituted):
        for observable in ("intensity", "polarimetry"):
            # Updating with already-known keys keeps their original position.
            seen.update(dict.fromkeys(timings[observable]))
    return list(seen)


def remove_loop_info(timing) -> str:
    """Render a timing result as text without the single-run/single-loop suffix.

    Args:
        timing: Any object whose ``str()`` is the timing text (in this notebook
            the result of a ``%%timeit -o`` cell), or ``None`` if no
            measurement was recorded.

    Returns:
        An empty string for ``None``. Otherwise ``str(timing)`` with the exact
        suffix that ``%%timeit -n1 -r1`` appends removed; text that does not
        contain that suffix is returned unchanged.
    """
    loop_suffix = " ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)"
    return "" if timing is None else str(timing).replace(loop_suffix, "")


# Assemble the summary table: one row per benchmark title and one column per
# (function kind, observable) pair. Missing measurements become empty strings.
row_titles = collect_sorted_row_title()
values = [
    tuple(
        remove_loop_info(timings[observable].get(row))
        for timings in (timing_parametrized, timing_substituted)
        for observable in ("intensity", "polarimetry")
    )
    for row in row_titles
]
columns = pd.MultiIndex.from_tuples(
    [
        (function_kind, symbol)
        for function_kind in ("parametrized", "substituted")
        for symbol in ("I", "ɑ")
    ],
)
df = pd.DataFrame(values, index=row_titles, columns=columns)
# Left-align header and data cells. The Styler is the last expression, so it is
# what the cell displays.
df.style.set_table_styles(
    [
        {"selector": selector, "props": [("text-align", "left")]}
        for selector in ("th", "td")
    ]
)