7.5. Benchmarking#
Tip
This notebook benchmarks JAX on a single CPU core. Compare with Julia results as reported in ComPWA/polarimetry#27. See also the Extended benchmark #68 discussion.
Note
This notebook uses only one run and one loop for %timeit, because JAX seems to cache its return values.
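A single run with a single loop keeps the one-time cost of a first call visible. With several loops, %timeit would average the slow first call together with the fast repeated ones. The sketch below illustrates the effect with the standard-library timeit module. The function slow_square and its functools.lru_cache decorator are hypothetical stand-ins for a JAX function and are not part of this notebook.

```python
import time
import timeit
from functools import lru_cache


@lru_cache(maxsize=None)
def slow_square(x: int) -> int:
    # Hypothetical stand-in for one-time work such as JIT compilation
    time.sleep(0.05)
    return x * x


# Equivalent of `%timeit -n1 -r1`: one run, one loop, so nothing is averaged
first_call = timeit.repeat(lambda: slow_square(3), number=1, repeat=1)[0]
second_call = timeit.repeat(lambda: slow_square(3), number=1, repeat=1)[0]

print(f"first call:  {first_call:.4f} s")
print(f"second call: {second_call:.6f} s")
```

Timing the two calls separately shows the one-time cost and the steady-state cost as two distinct numbers, which is how the cells in this notebook are organized.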
Import Python libraries
from __future__ import annotations
import logging
from collections import defaultdict
import numpy as np
import pandas as pd
import sympy as sp
from IPython.display import Markdown
from psutil import cpu_count
from polarimetry import formulate_polarimetry
from polarimetry.data import (
    create_data_transformer,
    generate_meshgrid_sample,
    generate_phasespace_sample,
)
from polarimetry.io import (
    mute_jax_warnings,
    perform_cached_doit,
    perform_cached_lambdify,
)
from polarimetry.lhcb import (
    load_model_builder,
    load_model_parameters,
    load_three_body_decay,
)
from polarimetry.lhcb.particle import load_particles
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)
mute_jax_warnings()
model_choice = 0
model_file = "../../data/model-definitions.yaml"
particles = load_particles("../../data/particle-definitions.yaml")
amplitude_builder = load_model_builder(model_file, particles, model_choice)
imported_parameter_values = load_model_parameters(
    model_file, amplitude_builder.decay, model_choice, particles
)
reference_subsystem = 1
model = amplitude_builder.formulate(reference_subsystem)
model.parameter_defaults.update(imported_parameter_values)
timing_parametrized = defaultdict(dict)
timing_substituted = defaultdict(dict)
print("Physical cores:", cpu_count(logical=False))
print("Total cores:", cpu_count(logical=True))
Physical cores: 8
Total cores: 8
%%time
polarimetry_exprs = formulate_polarimetry(amplitude_builder, reference_subsystem)
unfolded_polarimetry_exprs = [
    perform_cached_doit(expr.doit().xreplace(model.amplitudes))
    for expr in polarimetry_exprs
]
unfolded_intensity_expr = perform_cached_doit(model.full_expression)
CPU times: user 25 s, sys: 0 ns, total: 25 s
Wall time: 25.1 s
7.5.1. DataTransformer performance#
n_events = 100_000
phsp_sample = generate_phasespace_sample(model.decay, n_events, seed=0)
transformer = create_data_transformer(model)
%timeit -n1 -r1 transformer(phsp_sample) # first run: no cache yet, includes JIT compilation
%timeit -n1 -r1 transformer(phsp_sample) # second run with cache
%timeit -n1 -r1 transformer(phsp_sample) # third run with cache
phsp_sample = transformer(phsp_sample)
random_point = {k: v[0] if len(v.shape) > 0 else v for k, v in phsp_sample.items()}
524 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
25.5 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
25.6 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
res = 54
grid_sample = generate_meshgrid_sample(model.decay, res)
%timeit -n1 -r1 transformer(grid_sample) # first run, without cache, but already compiled
%timeit -n1 -r1 transformer(grid_sample) # second run with cache
%timeit -n1 -r1 transformer(grid_sample) # third run with cache
grid_sample = transformer(grid_sample)
483 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
2.73 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
1.99 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
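The first call on grid_sample (483 ms) takes about as long as the first call on phsp_sample (524 ms), even though the transformer had already been called. One possible explanation, stated here as an assumption that this notebook does not verify, is that jax.jit specializes the compiled code on the shape and dtype of its inputs, so an input with a new shape triggers a new trace and compilation. The sketch below shows that general JAX behavior on a hypothetical function, sum_of_squares, which is not part of the polarimetry package.

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # filled only while JAX traces the function


@jax.jit
def sum_of_squares(x):
    # Python side effects run during tracing, not when the compiled code is reused
    traced_shapes.append(x.shape)
    return jnp.sum(x**2)


small = jnp.ones(3)
large = jnp.ones(5)

sum_of_squares(small).block_until_ready()  # new shape (3,): traced and compiled
sum_of_squares(small).block_until_ready()  # same shape: compiled code is reused
sum_of_squares(large).block_until_ready()  # new shape (5,): traced and compiled again

print(traced_shapes)
```

Calling block_until_ready() makes sure the computation has finished before a timer stops, because JAX dispatches work asynchronously.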
7.5.2. Parametrized function#
Total number of mathematical operations:
\(\alpha_x\): 133,630
\(\alpha_y\): 133,634
\(\alpha_z\): 133,630
\(I_\mathrm{tot}\): 43,198
%%time
parametrized_polarimetry_funcs = [
    perform_cached_lambdify(
        expr,
        parameters=model.parameter_defaults,
        backend="jax",
    )
    for expr in unfolded_polarimetry_exprs
]
parametrized_intensity_func = perform_cached_lambdify(
    unfolded_intensity_expr,
    parameters=model.parameter_defaults,
    backend="jax",
)
CPU times: user 23.5 ms, sys: 0 ns, total: 23.5 ms
Wall time: 23.2 ms
rng = np.random.default_rng(seed=0)
original_parameters = dict(parametrized_intensity_func.parameters)
modified_parameters = {
    k: rng.uniform(0.9, 1.1) * v
    for k, v in parametrized_intensity_func.parameters.items()
}
7.5.2.1. One data point#
7.5.2.1.1. JIT-compilation#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(random_point)
<TimeitResult : 2.05 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](random_point)
array = parametrized_polarimetry_funcs[1](random_point)
array = parametrized_polarimetry_funcs[2](random_point)
<TimeitResult : 10.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.2.1.2. Compiled performance#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(random_point)
<TimeitResult : 1.41 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](random_point)
array = parametrized_polarimetry_funcs[1](random_point)
array = parametrized_polarimetry_funcs[2](random_point)
<TimeitResult : 2.2 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.2.2. 54x54 grid sample#
7.5.2.2.1. Compiled but uncached#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(grid_sample)
<TimeitResult : 2.31 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](grid_sample)
array = parametrized_polarimetry_funcs[1](grid_sample)
array = parametrized_polarimetry_funcs[2](grid_sample)
<TimeitResult : 13.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.2.2.2. Second run with cache#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(grid_sample)
<TimeitResult : 3.84 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](grid_sample)
array = parametrized_polarimetry_funcs[1](grid_sample)
array = parametrized_polarimetry_funcs[2](grid_sample)
<TimeitResult : 19 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.2.3. 100,000-event phase space sample#
7.5.2.3.1. Compiled but uncached#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(phsp_sample)
<TimeitResult : 2.33 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)
<TimeitResult : 13.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.2.3.2. Second run with cache#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(phsp_sample)
<TimeitResult : 63.5 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)
<TimeitResult : 235 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.2.4. Recompilation after parameter modification#
parametrized_intensity_func.update_parameters(modified_parameters)
for func in parametrized_polarimetry_funcs:
    func.update_parameters(modified_parameters)
7.5.2.4.1. Compiled but uncached#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(phsp_sample)
<TimeitResult : 2.33 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)
<TimeitResult : 13.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.2.4.2. Second run with cache#
%%timeit -n1 -r1 -q -o
array = parametrized_intensity_func(phsp_sample)
<TimeitResult : 53.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = parametrized_polarimetry_funcs[0](phsp_sample)
array = parametrized_polarimetry_funcs[1](phsp_sample)
array = parametrized_polarimetry_funcs[2](phsp_sample)
<TimeitResult : 286 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
parametrized_intensity_func.update_parameters(original_parameters)
for func in parametrized_polarimetry_funcs:
    func.update_parameters(original_parameters)
7.5.3. All parameters substituted#
subs_polarimetry_exprs = [
    expr.xreplace(model.parameter_defaults) for expr in unfolded_polarimetry_exprs
]
subs_intensity_expr = unfolded_intensity_expr.xreplace(model.parameter_defaults)
Number of mathematical operations after substituting all parameters:
\(\alpha_x\): 29,552
\(\alpha_y\): 29,556
\(\alpha_z\): 29,552
\(I_\mathrm{tot}\): 9,624
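The operation counts drop after substitution because SymPy re-evaluates each node once numbers replace symbols, so constant sub-expressions collapse. The sketch below shows the mechanism on a toy expression. It assumes the counts above were obtained with sympy.count_ops, which the visible cells do not confirm. The symbols a, b, c, x and their values are hypothetical.

```python
import sympy as sp

a, b, c, x = sp.symbols("a b c x")
expr = a * sp.exp(b) * x + c

# Hypothetical parameter values, chosen so that constant folding is easy to see
parameter_values = {a: sp.Integer(2), b: sp.Integer(0), c: sp.Integer(0)}
substituted = expr.xreplace(parameter_values)

ops_before = sp.count_ops(expr)
ops_after = sp.count_ops(substituted)

print("before:", expr, "->", ops_before, "operations")
print("after: ", substituted, "->", ops_after, "operations")
```

Here exp(0) evaluates to 1 and the added zero disappears, leaving only a product of a number and x.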
%%time
polarimetry_funcs = [
    perform_cached_lambdify(expr, backend="jax") for expr in subs_polarimetry_exprs
]
intensity_func = perform_cached_lambdify(subs_intensity_expr, backend="jax")
CPU times: user 11.8 ms, sys: 0 ns, total: 11.8 ms
Wall time: 12.2 ms
7.5.3.1. One data point#
7.5.3.1.1. JIT-compilation#
%%timeit -n1 -r1 -q -o
array = intensity_func(random_point)
<TimeitResult : 1.48 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = polarimetry_funcs[0](random_point)
array = polarimetry_funcs[1](random_point)
array = polarimetry_funcs[2](random_point)
<TimeitResult : 7.45 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.3.1.2. Compiled performance#
%%timeit -n1 -r1 -q -o
array = intensity_func(random_point)
<TimeitResult : 282 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = polarimetry_funcs[0](random_point)
array = polarimetry_funcs[1](random_point)
array = polarimetry_funcs[2](random_point)
<TimeitResult : 303 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.3.2. 54x54 grid sample#
7.5.3.2.1. Compiled but uncached#
%%timeit -n1 -r1 -q -o
array = intensity_func(grid_sample)
<TimeitResult : 1.62 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = polarimetry_funcs[0](grid_sample)
array = polarimetry_funcs[1](grid_sample)
array = polarimetry_funcs[2](grid_sample)
<TimeitResult : 8.64 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.3.2.2. Second run with cache#
%%timeit -n1 -r1 -q -o
array = intensity_func(grid_sample)
<TimeitResult : 4.77 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = polarimetry_funcs[0](grid_sample)
array = polarimetry_funcs[1](grid_sample)
array = polarimetry_funcs[2](grid_sample)
<TimeitResult : 23.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.3.3. 100,000-event phase space sample#
7.5.3.3.1. Compiled but uncached#
%%timeit -n1 -r1 -q -o
array = intensity_func(phsp_sample)
<TimeitResult : 1.69 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = polarimetry_funcs[0](phsp_sample)
array = polarimetry_funcs[1](phsp_sample)
array = polarimetry_funcs[2](phsp_sample)
<TimeitResult : 8.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.3.3.2. Second run with cache#
%%timeit -n1 -r1 -q -o
array = intensity_func(phsp_sample)
<TimeitResult : 46.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
%%timeit -n1 -r1 -q -o
array = polarimetry_funcs[0](phsp_sample)
array = polarimetry_funcs[1](phsp_sample)
array = polarimetry_funcs[2](phsp_sample)
<TimeitResult : 301 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)>
7.5.4. Summary#
def collect_sorted_row_title() -> list[str]:
    row_titles = {}
    row_titles.update(timing_parametrized["intensity"])
    row_titles.update(timing_parametrized["polarimetry"])
    row_titles.update(timing_substituted["intensity"])
    row_titles.update(timing_substituted["polarimetry"])
    return list(row_titles)


def remove_loop_info(timing) -> str:
    if timing is None:
        return ""
    pattern = " ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)"
    return str(timing).replace(pattern, "")


row_titles = collect_sorted_row_title()
values = [
    (
        remove_loop_info(timing_parametrized["intensity"].get(row)),
        remove_loop_info(timing_parametrized["polarimetry"].get(row)),
        remove_loop_info(timing_substituted["intensity"].get(row)),
        remove_loop_info(timing_substituted["polarimetry"].get(row)),
    )
    for row in row_titles
]
columns = pd.MultiIndex.from_tuples(
    [
        ("parametrized", "I"),
        ("parametrized", "ɑ"),
        ("substituted", "I"),
        ("substituted", "ɑ"),
    ],
)
df = pd.DataFrame(values, index=row_titles, columns=columns)
df.style.set_table_styles(
    [
        dict(selector="th", props=[("text-align", "left")]),
        dict(selector="td", props=[("text-align", "left")]),
    ]
)
| | parametrized I | parametrized ɑ | substituted I | substituted ɑ |
|---|---|---|---|---|
| random point (compilation) | 2.05 s | 10.8 s | 1.48 s | 7.45 s |
| random point (cached) | 1.41 ms | 2.2 ms | 282 µs | 303 µs |
| 54x54 grid | 2.31 s | 13.3 s | 1.62 s | 8.64 s |
| 54x54 grid (cached) | 3.84 ms | 19 ms | 4.77 ms | 23.1 ms |
| 100,000 phsp | 2.33 s | 13.1 s | 1.69 s | 8.9 s |
| 100,000 phsp (cached) | 63.5 ms | 235 ms | 46.8 ms | 301 ms |
| modified 100,000 phsp | 2.33 s | 13.3 s | | |
| modified 100,000 phsp (cached) | 53.9 ms | 286 ms | | |
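The timings in the table use mixed units, so it helps to convert them to seconds before comparing them. The sketch below does this for the parametrized intensity on a single random point, using the two values from the table. The helper to_seconds and the lookup table UNIT_TO_SECONDS are hypothetical and not part of this notebook.

```python
# Hypothetical helper: conversion factors for the units that appear in the table
UNIT_TO_SECONDS = {"s": 1.0, "ms": 1e-3, "µs": 1e-6, "ns": 1e-9}


def to_seconds(timing: str) -> float:
    """Convert a string such as '1.41 ms' to a number of seconds."""
    value, unit = timing.split()
    return float(value) * UNIT_TO_SECONDS[unit]


first_call = to_seconds("2.05 s")  # random point (compilation), parametrized I
cached_call = to_seconds("1.41 ms")  # random point (cached), parametrized I
ratio = first_call / cached_call

print(f"first call is about {ratio:.0f} times slower than a cached call")
```

The same conversion can be applied to any other pair of cells, for example to compare the parametrized and substituted columns for the cached 100,000-event sample.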