Model module

Model base

class emsa.model.model_base.EpidemicModelBase(data, model_struct: Dict[str, Any])

Bases: ABC

aggregate_by_age(solution, comp)

Aggregate the solution by age for a compartment.

Parameters:
  • solution (torch.Tensor) – Solution tensor.

  • comp (str) – Compartment name.

Returns:

Aggregated solution.

Return type:

torch.Tensor

get_compartments() list

Get the list of compartments.

Returns:

List of compartments.

Return type:

List[str]

get_initial_values_from_dict(init_val_dict: dict) FloatTensor

Retrieve the initial values of the model from the values provided in the corresponding configuration JSON.

Parameters:

init_val_dict (dict) – Initial values, as provided in the configuration JSON.

Returns:

Initial values of the model.

Return type:

torch.Tensor

get_sol_from_ode(y0: Tensor, t_eval: Tensor, odefun: Callable) Solution

Solve the ODE system using the Euler method with a step size of 1.

Parameters:
  • y0 (torch.Tensor) – Initial values.

  • t_eval (torch.Tensor) – Evaluation times.

  • odefun (Callable) – ODE function, called as odefun(t, y) and returning the derivative of the system.

Returns:

Solution of the ODE system.

Return type:

Solution
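The fixed-step scheme described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name euler_solve is hypothetical, it assumes t_eval holds consecutive integer times (so one Euler step of size 1 separates neighbouring evaluation times), and it returns a plain stacked tensor instead of the Solution object named in the signature.

```python
import torch
from typing import Callable


def euler_solve(odefun: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                y0: torch.Tensor, t_eval: torch.Tensor) -> torch.Tensor:
    """Hypothetical explicit Euler solver with a fixed step size of 1.

    Assumes t_eval contains consecutive integer times starting at t_eval[0].
    Returns a tensor of shape (len(t_eval), *y0.shape).
    """
    ys = [y0]
    y = y0
    for t in t_eval[:-1]:
        # One Euler step: y_{n+1} = y_n + h * f(t_n, y_n), with h = 1
        y = y + 1.0 * odefun(t, y)
        ys.append(y)
    return torch.stack(ys)


# Usage: exponential decay dy/dt = -0.5 * y starting from y(0) = 1
t_eval = torch.arange(3, dtype=torch.float32)
sol = euler_solve(lambda t, y: -0.5 * y, torch.tensor([1.0]), t_eval)
```

With h = 1 each step multiplies the state by (1 - 0.5), so sol holds 1.0, 0.5 and 0.25. A step size of 1 keeps the solver cheap and differentiable, at the cost of accuracy when rates are large relative to one time unit.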

abstract get_solution(y0, t_eval, **kwargs)

Get the solution of the epidemic model.

Parameters:
  • y0 (torch.Tensor) – Initial values.

  • t_eval (torch.Tensor) – Evaluation times.

  • **kwargs – Additional keyword arguments.

Returns:

Solution of the model.

Return type:

Any

idx(state: str) BoolTensor

Get a boolean mask selecting the entries of a given state.

Parameters:

state (str) – State name.

Returns:

Boolean mask tensor marking the entries of the state.

Return type:

torch.BoolTensor
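Since idx returns a boolean mask rather than integer indices, selecting a state's entries from a solution is a single masked index, and one plausible way to aggregate a compartment by age is to sum the masked columns of its substates. The sketch below assumes a hypothetical age-major layout (every age group holds the states s, e_0, e_1, r in that order) and a solution of shape (n_times, n_eq); the helpers state_mask and aggregate_compartment_by_age are illustrative stand-ins for idx and aggregate_by_age, not their actual code.

```python
import torch

# Hypothetical layout: age-major, each age group holds these states in order.
STATES = ["s", "e_0", "e_1", "r"]
N_AGE = 2
N_EQ = N_AGE * len(STATES)


def state_mask(state: str) -> torch.Tensor:
    """Boolean mask over the n_eq columns selecting `state` in every age group."""
    mask = torch.zeros(N_EQ, dtype=torch.bool)
    mask[STATES.index(state)::len(STATES)] = True
    return mask


def aggregate_compartment_by_age(solution: torch.Tensor, comp: str) -> torch.Tensor:
    """Sum all substates of `comp` (e.g. e_0 and e_1), keeping age groups separate.

    `solution` has shape (n_times, n_eq); the result has shape (n_times, n_age).
    """
    substates = [s for s in STATES if s == comp or s.startswith(comp + "_")]
    if not substates:
        raise KeyError(f"unknown compartment: {comp}")
    return sum(solution[:, state_mask(s)] for s in substates)


# Usage: two time points, columns numbered 0..7 for easy checking
solution = torch.arange(2 * N_EQ, dtype=torch.float32).reshape(2, N_EQ)
e_by_age = aggregate_compartment_by_age(solution, "e")
```

In the first row, age group 0 stores e_0 and e_1 in columns 1 and 2, and age group 1 in columns 5 and 6, so e_by_age[0] is [3.0, 11.0].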

initialize_matrices()

Initialize the matrices used in the model.

validate_params()
visualize_transmission_graph()
emsa.model.model_base.get_substates(n_substates, comp_name)

Get the list of substates for a compartment.

Parameters:
  • n_substates (int) – Number of substates.

  • comp_name (str) – Compartment name.

Returns:

List of state names.

Return type:

List[str]
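The naming convention is not stated here; the get_A description in the matrix generator section refers to states such as s^i and e_0^i, which suggests an index suffix on substates. The following sketch assumes that convention, including the guess that a compartment with a single substate keeps its bare name; get_substates_sketch is a hypothetical name.

```python
from typing import List


def get_substates_sketch(n_substates: int, comp_name: str) -> List[str]:
    """Hypothetical naming scheme: 'e' with 2 substates -> ['e_0', 'e_1']."""
    if n_substates == 1:
        # Assumed: a single substate keeps the bare compartment name
        return [comp_name]
    return [f"{comp_name}_{k}" for k in range(n_substates)]


print(get_substates_sketch(2, "e"))  # ['e_0', 'e_1']
print(get_substates_sketch(1, "s"))  # ['s']
```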

Epidemic model

class emsa.model.epidemic_model.EpidemicModel(data, model_struct: dict)

Bases: EpidemicModelBase

basic_ode(t: Tensor, y: Tensor) Tensor

Basic ODE function without vaccination.

Parameters:
  • t (torch.Tensor) – Current time.

  • y (torch.Tensor) – Current state.

Returns:

Derivative of the system.

Return type:

torch.Tensor

get_solution(y0: Tensor, t_eval: Tensor, **kwargs) Solution

Get the solution of the ODE from the initial conditions and evaluation times, using the ODE solver in EpidemicModelBase.

Parameters:
  • y0 (torch.Tensor) – Initial state.

  • t_eval (torch.Tensor) – Times at which to evaluate the solution.

  • **kwargs – Additional keyword arguments.

Returns:

Solution of the ODE.

Return type:

Solution

Matrix generator

class emsa.model.matrix_generator.MatrixGenerator(model: EpidemicModelBase, cm)

Bases: object

Class responsible for generating the matrices used in the model.

Let y be a tensor of shape (n_samples, n_eq), where each row corresponds to a different simulation. The system of ODEs is represented in the following general form:

y' = (y @ T_1) * (y @ T_2) + y @ L,

where @ denotes matrix multiplication and * element-wise multiplication, T_1 and T_2 encode the transmission of the disease, and L encodes the linear terms of the system.

In addition, continuous vaccination can be written as

vacc = (y @ V_1) / (y @ V_2),

where / denotes element-wise division.
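To make the row-vector convention concrete, the sketch below evaluates y' = (y @ T_1) * (y @ T_2) + y @ L for a batch of two simulations of a hypothetical single-age-group SIR model (states s, i, r, total population N). The matrices are hand-built for illustration and are not the output of MatrixGenerator. Entry M[a, b] routes a contribution from state a into the derivative of state b: y @ T_1 carries the force of infection beta * I / N (negative in column s, positive in column i), y @ T_2 copies S into both columns, and L moves recoveries from i to r.

```python
import torch


def rhs(y: torch.Tensor, T_1: torch.Tensor, T_2: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """y' = (y @ T_1) * (y @ T_2) + y @ L for a batch y of shape (n_samples, n_eq)."""
    return (y @ T_1) * (y @ T_2) + y @ L


# Hypothetical SIR model, one age group, states ordered (s, i, r)
S, I, R = 0, 1, 2
N, beta, gamma = 1000.0, 0.3, 0.1

T_1 = torch.zeros(3, 3, dtype=torch.float64)
T_1[I, S] = -beta / N  # column s of y @ T_1 holds -beta * I / N
T_1[I, I] = beta / N   # column i of y @ T_1 holds +beta * I / N

T_2 = torch.zeros(3, 3, dtype=torch.float64)
T_2[S, S] = 1.0        # columns s and i of y @ T_2 both hold S
T_2[S, I] = 1.0

L = torch.zeros(3, 3, dtype=torch.float64)
L[I, I] = -gamma       # recoveries leave i ...
L[I, R] = gamma        # ... and enter r

# Two simulations, one per row
y = torch.tensor([[990.0, 10.0, 0.0],
                  [500.0, 500.0, 0.0]], dtype=torch.float64)
dy = rhs(y, T_1, T_2, L)
```

Each row of dy sums to zero, reflecting that these flows only move people between compartments; the first row is [-2.97, 1.97, 1.0].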

cm

The contact matrix.

Type:

torch.Tensor

ps

A dictionary containing model parameters.

Type:

Dict[str, float]

data

Model data.

Type:

Any

state_data

Data related to states.

Type:

Dict[str, Any]

trans_data

Data related to transitions.

Type:

List[Dict[str, Any]]

tms_rules

Transmission rules.

Type:

List[Dict[str, Any]]

n_eq

The total number of compartments in the model.

Type:

int

n_age

The number of age groups.

Type:

int

n_comp

The number of states.

Type:

int

population

The total population.

Type:

torch.Tensor

device

The device to be used for computations.

Type:

torch.device

idx

A dictionary containing the indices of different compartments.

Type:

Dict[str, int]

c_idx

A dictionary containing the indices of different compartments’ components.

Type:

Dict[str, int]

_get_comp_slice(comp)

Get a slice representing the indices of a given compartment.

_get_end_state(comp)

Get the string representing the last state of a given compartment.

_get_trans_param_dict()

Get a dictionary of transition parameters for different compartments.

generate_matrix(matrix_name: str) Tensor
get_A() Tensor
Returns:

When y is multiplied by this matrix, the result contains the transmission rate for the susceptibles of age group i at the indices of compartments s^i and e_0^i.

Return type:

torch.Tensor

get_B() Tensor
get_T(cm=None) Tensor
get_V_1(daily_vac=None) Tensor
get_V_2() Tensor
get_end_state(comp: str) str

Get the string representing the last state of a given compartment.

Parameters:

comp (str) – The compartment name.

Returns:

The last state of the compartment.

Return type:

str

emsa.model.matrix_generator.generate_transition_block(transition_param: float, n_states: int) Tensor

Generate a transition block for the transition matrix.

Parameters:
  • transition_param (float) – The transition parameter value.

  • n_states (int) – The number of states in the block.

Returns:

The transition block.

Return type:

torch.Tensor
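One plausible shape for such a block, assuming substates form a chain in which each substate passes to the next at rate transition_param, is sketched below in the same row-vector convention as y @ L (entry [a, b] feeds state b from state a). This is an assumption, not the documented behaviour: the source does not say whether the block includes the outflow from the last substate, and here that outflow is left to be routed to the next compartment outside the block. transition_block_sketch is a hypothetical name.

```python
import torch


def transition_block_sketch(transition_param: float, n_states: int) -> torch.Tensor:
    """Hypothetical chain block: substate k leaves at rate transition_param into k + 1."""
    # Every substate loses mass at rate transition_param (diagonal)
    block = -transition_param * torch.eye(n_states)
    for k in range(n_states - 1):
        # Row-vector convention: mass leaving substate k enters substate k + 1
        block[k, k + 1] = transition_param
    return block


block = transition_block_sketch(0.5, 3)
```

Here block is [[-0.5, 0.5, 0], [0, -0.5, 0.5], [0, 0, -0.5]]: the first two rows sum to zero, while the last row carries only the outflow that, under this assumption, is placed elsewhere in the full transition matrix.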

emsa.model.matrix_generator.generate_transition_matrix(states_dict: Dict[str, Any], parameters: Dict[str, float], n_age: int, n_comp: int, c_idx: Dict[str, int]) Tensor

Generate the transition matrix for the model.

Parameters:
  • states_dict (Dict[str, Any]) – Dictionary of states.

  • parameters (Dict[str, float]) – A dictionary containing model parameters.

  • n_age (int) – The number of age groups.

  • n_comp (int) – The number of compartments.

  • c_idx (Dict[str, int]) – A dictionary containing the indices of different compartments.

Returns:

The transition matrix.

Return type:

torch.Tensor

emsa.model.matrix_generator.get_inf_mul(tms_rule: dict, data) Tensor

Get the infection multiplier.

Parameters:
  • tms_rule (Dict[str, Any]) – Transmission rule.

  • data (Any) – Model data.

Returns:

Infection multiplier.

Return type:

torch.Tensor

emsa.model.matrix_generator.get_param_mul(trans_params: List[str], params: dict) float

Get the transition parameters multiplier.

Parameters:
  • trans_params (List[str]) – Keys of the distribution parameters.

  • params (Dict[str, float]) – Model parameters.

Returns:

Transition parameters multiplier.

Return type:

float

emsa.model.matrix_generator.get_susc_mul(tms_rule: dict, data) Tensor

Get the susceptibility multiplier.

Parameters:
  • tms_rule (Dict[str, Any]) – Transmission rule.

  • data (Any) – Model data.

Returns:

Susceptibility multiplier.

Return type:

torch.Tensor

R0 calculator

class emsa.model.r0_calculator.R0Generator(data, model_struct: Dict[str, Any])

Bases: object

get_eig_val(susceptibles: Tensor, population: Tensor, contact_mtx: Tensor) float

Compute the dominant eigenvalue of the next-generation matrix (NGM).

Parameters:
  • susceptibles (torch.Tensor) – Susceptible population.

  • population (torch.Tensor) – Total population.

  • contact_mtx (torch.Tensor) – Contact matrix.

Returns:

Dominant eigenvalue representing the basic reproduction number (R0).

Return type:

float
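As an illustration of the dominant-eigenvalue step, the sketch below builds a next-generation matrix for a hypothetical age-structured SIR model with transmission rate beta and recovery rate gamma, K[i, j] = (beta / gamma) * S_i * C[i, j] / N_j, and returns its spectral radius. The NGM actually assembled by R0Generator depends on the model structure and is not reproduced here; r0_sketch, beta and gamma are assumptions.

```python
import torch


def r0_sketch(susceptibles: torch.Tensor, population: torch.Tensor,
              contact_mtx: torch.Tensor, beta: float, gamma: float) -> float:
    # Hypothetical SIR-type NGM: K[i, j] = (beta / gamma) * S_i * C[i, j] / N_j
    ngm = (beta / gamma) * susceptibles[:, None] * contact_mtx / population[None, :]
    # Dominant eigenvalue = spectral radius (largest |lambda|) of the NGM
    return torch.linalg.eigvals(ngm).abs().max().item()


pop = torch.tensor([100.0, 200.0], dtype=torch.float64)
C = torch.tensor([[2.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
r0 = r0_sketch(pop, pop, C, beta=0.05, gamma=0.1)  # fully susceptible population
```

For a fully susceptible population, K = (beta / gamma) * diag(N) C diag(1/N) is similar to (beta / gamma) * C, so the result reduces to (beta / gamma) times the spectral radius of C; with the values above this is 0.5 * (5 + sqrt(5)) / 2 ≈ 1.809. Halving the susceptibles halves the result, since K is linear in S.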

get_infected_states()

Get the list of infected states.

Returns:

List of infected states.

Return type:

list

isinf_state(state)