Model module
Model base
- class emsa.model.model_base.EpidemicModelBase(data, model_struct: Dict[str, Any])
Bases:
ABC
- aggregate_by_age(solution, comp)
Aggregate the solution by age for a compartment.
- Parameters:
solution (torch.Tensor) – Solution tensor.
comp (str) – Compartment name.
- Returns:
Aggregated solution.
- Return type:
torch.Tensor
- get_compartments() → list
Get the list of compartments.
- Returns:
List of compartments.
- Return type:
List[str]
- get_initial_values_from_dict(init_val_dict: dict) → FloatTensor
Retrieve the initial values of the model from the values provided in the corresponding configuration JSON.
- Returns:
Initial values of the model.
- Return type:
torch.Tensor
- get_sol_from_ode(y0: Tensor, t_eval: Tensor, odefun: Callable) → Solution
Solve the ODE system using the Euler method with a step size of 1.
- Parameters:
y0 (torch.Tensor) – Initial values.
t_eval (torch.Tensor) – Evaluation times.
odefun (Callable) – ODE function.
- Returns:
Solution of the ODE system.
- Return type:
Any
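The Euler scheme with step size 1 reduces to y_{k+1} = y_k + f(t_k, y_k). The following is a minimal sketch of that idea, not the library's implementation: the function name euler_solve_sketch is hypothetical, it assumes t_eval holds consecutive integer times, and it returns a plain stacked tensor rather than the library's Solution object, whose structure is not documented here.

```python
from typing import Callable

import torch


def euler_solve_sketch(
    y0: torch.Tensor,
    t_eval: torch.Tensor,
    odefun: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Explicit Euler with a fixed step size of 1 (hypothetical helper).

    Assumes t_eval holds consecutive integer times (t0, t0 + 1, ...), so one
    Euler step advances exactly one entry of t_eval. Returns a tensor of shape
    (len(t_eval), *y0.shape) with the state at every evaluation time.
    """
    y = y0
    states = [y0]
    for t in t_eval[:-1]:
        # h = 1, so y_{k+1} = y_k + h * f(t_k, y_k) reduces to y_k + f(t_k, y_k)
        y = y + odefun(t, y)
        states.append(y)
    return torch.stack(states)


# Usage: dy/dt = -0.5 * y for two independent simulations (one per row).
y0 = torch.tensor([[1.0], [2.0]])
sol = euler_solve_sketch(y0, torch.arange(4), lambda t, y: -0.5 * y)
# First simulation: 1.0, 0.5, 0.25, 0.125 (each unit step halves the state)
```

Because the step size is 1, a rate larger than 1 per time unit would overshoot and can drive compartments negative; that is a general property of explicit Euler, worth keeping in mind when choosing parameter values.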
- abstract get_solution(y0, t_eval, **kwargs)
Get the solution of the epidemic model.
- Parameters:
y0 (torch.Tensor) – Initial values.
t_eval (torch.Tensor) – Evaluation times.
**kwargs – Additional keyword arguments.
- Returns:
Solution of the model.
- Return type:
Any
- idx(state: str) → BoolTensor
Get the boolean index mask for a given state.
- Parameters:
state (str) – State name.
- Returns:
Index tensor.
- Return type:
torch.BoolTensor
- initialize_matrices()
Initialize the matrices used in the model.
- validate_params()
- visualize_transmission_graph()
- emsa.model.model_base.get_substates(n_substates, comp_name)
Get the list of substates for a compartment.
- Parameters:
n_substates (int) – Number of substates.
comp_name (str) – Compartment name.
- Returns:
List of state names.
- Return type:
List[str]
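A possible shape for the substate list, assuming a naming scheme loosely inspired by the s^i and e_0^i notation used in get_A: a compartment with a single substate keeps its plain name, and otherwise substates are numbered from 0. The helper get_substates_sketch is hypothetical; the actual naming used by get_substates is not documented here.

```python
from typing import List


def get_substates_sketch(n_substates: int, comp_name: str) -> List[str]:
    """Hypothetical naming scheme for the substates of one compartment."""
    if n_substates == 1:
        # A compartment without substates keeps its plain name, e.g. "s".
        return [comp_name]
    # Otherwise number the substates from 0, e.g. "e_0", "e_1", ...
    return [f"{comp_name}_{i}" for i in range(n_substates)]


print(get_substates_sketch(3, "e"))  # ['e_0', 'e_1', 'e_2']
print(get_substates_sketch(1, "s"))  # ['s']
```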
Epidemic model
- class emsa.model.epidemic_model.EpidemicModel(data, model_struct: dict)
Bases:
EpidemicModelBase
- basic_ode(t: Tensor, y: Tensor) → Tensor
Basic ODE function without vaccination.
- Parameters:
t (torch.Tensor) – Current time.
y (torch.Tensor) – Current state.
- Returns:
Derivative of the system.
- Return type:
torch.Tensor
- get_solution(y0: Tensor, t_eval: Tensor, **kwargs) → Solution
Solve the ODE from the given initial conditions at the given evaluation times, using the ODE solver of EpidemicModelBase.
- Parameters:
y0 (torch.Tensor) – Initial state.
t_eval (torch.Tensor) – Times at which to evaluate the solution.
**kwargs – Additional keyword arguments.
- Returns:
Solution of the ODE.
- Return type:
torch.Tensor
Matrix generator
- class emsa.model.matrix_generator.MatrixGenerator(model: EpidemicModelBase, cm)
Bases:
object
Class responsible for generating the matrices used in the model.
Let y be a tensor of shape (n_samples, n_eq), where each row corresponds to a different simulation. The system of ODEs is represented by the following general formula:
y' = (y @ T_1) * (y @ T_2) + y @ L,
where @ denotes matrix multiplication and * element-wise multiplication, T_1 and T_2 are responsible for the transmission of the disease, and L defines the linear changes between compartments.
Continuous vaccination can additionally be written as
vacc = (y @ V_1) / (y @ V_2),
where the division is element-wise.
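To make the matrix form concrete, here is a toy single-age SIR model written as y' = (y @ T_1) * (y @ T_2) + y @ L. The matrices are hand-built for illustration and are not produced by MatrixGenerator; the state ordering (S, I, R), the parameter values, and a population normalized to 1 are assumptions of this sketch. The vaccination term is only evaluated as the formula states, with toy V_1 and V_2 chosen so that each entry is divided by the row total; how the library feeds vacc into the ODE is not covered here.

```python
import torch


def rhs(y, T_1, T_2, L):
    """y' = (y @ T_1) * (y @ T_2) + y @ L for a batch y of shape (n_samples, n_eq)."""
    return (y @ T_1) * (y @ T_2) + y @ L


def vacc(y, V_1, V_2):
    """vacc = (y @ V_1) / (y @ V_2), with element-wise division."""
    return (y @ V_1) / (y @ V_2)


# Toy single-age SIR with states ordered (S, I, R), population normalized to 1.
beta, gamma = 0.3, 0.1
S, I, R = 0, 1, 2
T_1 = torch.zeros(3, 3)
T_1[S, S], T_1[S, I] = -1.0, 1.0   # y @ T_1 -> (-S, S, 0)
T_2 = torch.zeros(3, 3)
T_2[I, S], T_2[I, I] = beta, beta  # y @ T_2 -> (beta * I, beta * I, 0)
L = torch.zeros(3, 3)
L[I, I], L[I, R] = -gamma, gamma   # linear term: recovery I -> R at rate gamma

y = torch.tensor([[0.9, 0.1, 0.0]])
dy = rhs(y, T_1, T_2, L)  # approximately [[-0.027, 0.017, 0.010]]

# Toy vaccination matrices: V_1 = identity, V_2 = all ones (row total).
v = vacc(torch.tensor([[2.0, 4.0]]), torch.eye(2), torch.ones(2, 2))  # [[1/3, 2/3]]
```

The product of two linear maps is what lets a single pair of matrix multiplications express the bilinear S * I transmission term for the whole batch at once, while everything else stays in L.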
- cm
The contact matrix.
- Type:
torch.Tensor
- ps
A dictionary containing model parameters.
- Type:
Dict[str, float]
- data
Model data.
- Type:
Any
- state_data
Data related to states.
- Type:
Dict[str, Any]
- trans_data
Data related to transitions.
- Type:
List[Dict[str, Any]]
- tms_rules
Transmission rules.
- Type:
List[Dict[str, Any]]
- n_eq
The total number of compartments in the model.
- Type:
int
- n_age
The number of age groups.
- Type:
int
- n_comp
The number of states.
- Type:
int
- population
The total population.
- Type:
torch.Tensor
- device
The device to be used for computations.
- Type:
torch.device
- idx
A dictionary containing the indices of different compartments.
- Type:
Dict[str, int]
- c_idx
A dictionary containing the indices of different compartments’ components.
- Type:
Dict[str, int]
- _get_comp_slice(comp)
Get a slice representing the indices of a given compartment.
- _get_end_state(comp)
Get the string representing the last state of a given compartment.
- _get_trans_param_dict()
Get a dictionary of transition params for different compartments.
- generate_matrix(matrix_name: str) → Tensor
- get_A() → Tensor
- Returns:
When multiplied with y, the resulting tensor contains the rate of transmission for the susceptibles of age group i at the indices of compartments s^i and e_0^i.
- Return type:
torch.Tensor
- get_B() → Tensor
- get_T(cm=None) → Tensor
- get_V_1(daily_vac=None) → Tensor
- get_V_2() → Tensor
- get_end_state(comp: str) → str
Get the string representing the last state of a given compartment.
- Parameters:
comp (str) – The compartment name.
- Returns:
The last state of the compartment.
- Return type:
str
- emsa.model.matrix_generator.generate_transition_block(transition_param: float, n_states: int) → Tensor
Generate a transition block for the transition matrix.
- Parameters:
transition_param (float) – The transition parameter value.
n_states (int) – The number of states in the block.
- Returns:
The transition block.
- Return type:
torch.Tensor
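One plausible layout for such a block, assuming the row-vector convention y' = y @ L used above and a linear chain of substates: substate k loses mass at rate transition_param (diagonal) and that mass enters substate k + 1 (superdiagonal). The function transition_block_sketch is hypothetical; in particular, whether the library scales the rate by n_states, and where outflow from the last substate is placed, is not documented here.

```python
import torch


def transition_block_sketch(transition_param: float, n_states: int) -> torch.Tensor:
    """Hypothetical linear-chain block for the row-vector convention y' = y @ L.

    Substate k loses mass at rate transition_param (diagonal entry) and that
    mass enters substate k + 1 (superdiagonal entry). Outflow from the last
    substate is assumed to be routed to the next compartment elsewhere.
    """
    block = -transition_param * torch.eye(n_states)
    k = torch.arange(n_states - 1)
    block[k, k + 1] = transition_param
    return block


print(transition_block_sketch(0.5, 3))
# rows: [-0.5, 0.5, 0.0], [0.0, -0.5, 0.5], [0.0, 0.0, -0.5]
```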
- emsa.model.matrix_generator.generate_transition_matrix(states_dict: Dict[str, Any], parameters: Dict[str, float], n_age: int, n_comp: int, c_idx: Dict[str, int]) → Tensor
Generate the transition matrix for the model.
- Parameters:
states_dict (Dict[str, Any]) – Dictionary of states.
parameters (Dict[str, float]) – A dictionary containing model parameters.
n_age (int) – The number of age groups.
n_comp (int) – The number of compartments.
c_idx (Dict[str, int]) – A dictionary containing the indices of different compartments.
- Returns:
The transition matrix.
- Return type:
torch.Tensor
- emsa.model.matrix_generator.get_inf_mul(tms_rule: dict, data) → Tensor
Get the infection multiplier.
- Parameters:
tms_rule (Dict[str, Any]) – Transmission rule.
data (Any) – Model data.
- Returns:
Infection multiplier.
- Return type:
torch.Tensor
- emsa.model.matrix_generator.get_param_mul(trans_params: List[str], params: dict) → float
Get the transition parameters multiplier.
- Parameters:
trans_params (List[str]) – Distribution parameters keys.
params (Dict[str, float]) – Model parameters.
- Returns:
Transition parameters multiplier.
- Return type:
float
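A minimal sketch of one way to combine the listed parameters into a single multiplier, assuming (not confirmed by this reference) that the multiplier is the product of the named parameter values, so that, for example, a probability p and a rate gamma combine into the effective rate p * gamma. The name param_mul_sketch is hypothetical.

```python
import math
from typing import Dict, List


def param_mul_sketch(trans_params: List[str], params: Dict[str, float]) -> float:
    """Assumed behaviour: product of the parameters named in trans_params."""
    # start=1.0 makes an empty key list yield the neutral multiplier 1.0 as a float.
    return math.prod((params[key] for key in trans_params), start=1.0)


print(param_mul_sketch(["p", "gamma"], {"p": 0.5, "gamma": 0.2, "beta": 0.3}))  # 0.1
```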
- emsa.model.matrix_generator.get_susc_mul(tms_rule: dict, data) → Tensor
Get the susceptibility multiplier.
- Parameters:
tms_rule (Dict[str, Any]) – Transmission rule.
data (Any) – Model data.
- Returns:
Susceptibility multiplier.
- Return type:
torch.Tensor
R0 calculator
- class emsa.model.r0_calculator.R0Generator(data, model_struct: Dict[str, Any])
Bases:
object
- get_eig_val(susceptibles: Tensor, population: Tensor, contact_mtx: Tensor) → float
Compute the dominant eigenvalue of the next-generation matrix (NGM).
- Parameters:
susceptibles (torch.Tensor) – Susceptible population.
population (torch.Tensor) – Total population.
contact_mtx (torch.Tensor) – Contact matrix.
- Returns:
Dominant eigenvalue representing the basic reproduction number (R0).
- Return type:
float
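The dominant-eigenvalue step can be illustrated with the standard next-generation construction K = F V^{-1}, where F collects the rates of new infections and V the remaining transitions between infected states, and R0 is the spectral radius of K. This is a generic sketch, not the library's code: dominant_eigenvalue is a hypothetical name, it uses the textbook column-vector convention, and it does not reproduce how get_eig_val assembles the matrices from susceptibles, population and contact_mtx. The toy example is a single-age SEIR model with a fully susceptible population, for which R0 = beta / gamma.

```python
import torch


def dominant_eigenvalue(F: torch.Tensor, V: torch.Tensor) -> float:
    """Spectral radius of the next-generation matrix K = F V^{-1}.

    F holds the rates of new infections and V the other transitions between
    the infected states (column-vector convention x' = (F - V) x).
    """
    K = F @ torch.linalg.inv(V)
    # eigvals returns complex values; R0 is the largest modulus.
    return torch.linalg.eigvals(K).abs().max().item()


# Toy single-age SEIR, infected states ordered (E, I), fully susceptible.
beta, sigma, gamma = 0.5, 0.2, 0.25
F = torch.tensor([[0.0, beta],        # infectious I create new E at rate beta
                  [0.0, 0.0]], dtype=torch.float64)
V = torch.tensor([[sigma, 0.0],       # E leaves at rate sigma ...
                  [-sigma, gamma]],   # ... and enters I; I recovers at rate gamma
                 dtype=torch.float64)
r0 = dominant_eigenvalue(F, V)  # approximately beta / gamma = 2.0
```

For an age-structured model, F would additionally carry the contact matrix weighted by the susceptible fraction of each age group, which is the kind of input get_eig_val receives.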
- get_infected_states()
Get the list of infected states.
- Returns:
List of infected states.
- Return type:
list
- isinf_state(state)