from __future__ import annotations
import logging
import re
from typing import Pattern
import jax.numpy as jnp
from tensorwaves.interface import DataSample, ParametrizedFunction
_LOGGER = logging.getLogger(__name__)
def compute_sub_function(
    func: ParametrizedFunction,
    input_data: DataSample,
    non_zero_couplings: list[str],
):
    r"""Evaluate ``func`` with a subset of its parameters temporarily set to zero.

    A search pattern is built that matches parameter names starting with
    ``\mathcal{H}`` and containing a ``[`` (optionally followed by ``LS,``)
    that is *not* directly followed by one of ``non_zero_couplings``. All
    parameters matching that pattern are set to zero, ``func`` is evaluated on
    ``input_data``, and the original parameter values are restored afterwards.

    Args:
        func: The function to evaluate. Its parameters are modified during the
            call and restored before returning.
        input_data: Data passed unchanged to ``func``.
        non_zero_couplings: Regular-expression fragments (strings). They are
            joined with ``|`` inside a negative lookahead, so any regex
            metacharacter in them is interpreted as such.

    Returns:
        Whatever ``func(input_data)`` returns while the selected parameters
        are zero.
    """
    # Copy first so the original values survive the temporary update.
    old_parameters = dict(func.parameters)
    # NOTE(review): because "(LS,)?" is optional, the regex engine may
    # backtrack to an empty match of that group, in which case the lookahead
    # is tested against "LS,..." instead of the text after it. Confirm against
    # the actual parameter naming that this cannot zero a wanted coupling.
    pattern = rf"\\mathcal{{H}}.*\[(LS,)?(?!{'|'.join(non_zero_couplings)})"
    try:
        set_parameter_to_zero(func, pattern)
        return func(input_data)
    finally:
        # Restore even if the evaluation raised, so that `func` is never left
        # with zeroed parameters.
        func.update_parameters(old_parameters)
def set_parameter_to_zero(
    func: ParametrizedFunction, search_term: str | Pattern
) -> None:
    """Set every parameter of ``func`` whose name matches ``search_term`` to zero.

    Matching uses `re.match` semantics, i.e. the pattern has to match at the
    *start* of the parameter name. Parameters that do not match keep their
    current value. A warning is logged if no parameter name matches.

    Args:
        func: Object exposing a ``parameters`` mapping and an
            ``update_parameters`` method; it is modified in place.
        search_term: Regular expression (string or compiled pattern) that
            selects the parameter names to set to zero.
    """
    # Compile once instead of handing the raw pattern to `re` on every name.
    matcher = re.compile(search_term)
    new_parameters = dict(func.parameters)
    selected = [name for name in new_parameters if matcher.match(name) is not None]
    for name in selected:
        new_parameters[name] = 0
    if not selected:
        # Nothing matched, so nothing was changed: say so explicitly.
        _LOGGER.warning(
            "No parameters were set to zero: no parameter name matches"
            " search term %s",
            search_term,
        )
    func.update_parameters(new_parameters)
def interference_intensity(func, data, chain1: list[str], chain2: list[str]) -> float:
    """Return the combined sub-intensity minus the two individual ones.

    Three sub-intensities are computed with `sub_intensity`: one with the
    couplings of both chains kept (``chain1 + chain2``), one with only
    ``chain1`` and one with only ``chain2``. The result is the first value
    with the other two subtracted.

    Args:
        func: Function forwarded to `sub_intensity`.
        data: Data sample forwarded to `sub_intensity`.
        chain1: Coupling selectors for the first chain.
        chain2: Coupling selectors for the second chain.
    """
    both_chains = chain1 + chain2
    combined = sub_intensity(func, data, both_chains)
    only_first = sub_intensity(func, data, chain1)
    only_second = sub_intensity(func, data, chain2)
    remainder = combined - only_first
    return remainder - only_second
def sub_intensity(func, data, non_zero_couplings: list[str]):
    """Integrate the intensity obtained with only the given couplings kept.

    The array returned by `compute_sub_function` for ``non_zero_couplings`` is
    passed straight to `integrate_intensity`, whose result is returned.

    Args:
        func: Function forwarded to `compute_sub_function`.
        data: Data sample forwarded to `compute_sub_function`.
        non_zero_couplings: Coupling selectors forwarded to
            `compute_sub_function`.
    """
    return integrate_intensity(
        compute_sub_function(func, data, non_zero_couplings)
    )
def integrate_intensity(intensities) -> float:
    """Return the mean of all non-NaN entries of ``intensities``.

    The array is flattened, NaN entries are dropped, and the sum of the
    remaining values is divided by their count. The result is converted to a
    Python `float`.

    Args:
        intensities: Array-like object providing a ``flatten()`` method.
    """
    flat = intensities.flatten()
    # Keep only the entries that are not NaN.
    keep_mask = ~jnp.isnan(flat)
    valid = flat[keep_mask]
    total = jnp.sum(valid)
    return float(total / len(valid))