Class KoopmanDataSet
The class KoopmanDataSet is a subclass of torch.utils.data.Dataset; it provides data for EDMDSolver and EDMDDLSolver.
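The Dataset protocol that KoopmanDataSet follows can be illustrated with a minimal stand-in. The class below, ToyKoopmanDataSet, is hypothetical: it takes pre-built tensors instead of a DiscreteDynamics object and only mirrors the documented data_x and labels properties, __len__ and __getitem__. It is a sketch of the interface, not the library's implementation.

```python
import torch
from torch.utils.data import Dataset


class ToyKoopmanDataSet(Dataset):
    """Hypothetical stand-in mirroring the documented KoopmanDataSet interface."""

    def __init__(self, data_x: torch.Tensor, labels: torch.Tensor):
        # Each input sample must have exactly one label.
        assert len(data_x) == len(labels), "data_x and labels must have equal length"
        self._data_x = data_x
        self._labels = labels

    @property
    def data_x(self) -> torch.Tensor:
        # Read-only access to the input samples.
        return self._data_x

    @property
    def labels(self) -> torch.Tensor:
        # Read-only access to the labels.
        return self._labels

    def __len__(self) -> int:
        # One sample per row of data_x.
        return len(self._data_x)

    def __getitem__(self, idx: int):
        # Return an (input, label) tuple, as the documented __getitem__ does.
        return self._data_x[idx], self._labels[idx]


# Usage: three 2-D states paired with illustrative successor states.
x = torch.tensor([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]])
y = x + 1.0
ds = ToyKoopmanDataSet(x, y)
print(len(ds))  # 3
print(ds[1])
```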
API Documentation
Bases: Dataset
data_x (property)
Returns the generated data for the x component of the dataset.
labels (property)
Returns the labels of the dataset.
__getitem__(idx)
Retrieve the data and label at the specified index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int | The index of the data item to retrieve. | required |
Returns:
Type | Description |
---|---|
tuple | A tuple containing the data and its corresponding label at the specified index. |
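Because __getitem__ returns an (input, label) tuple, a dataset of this kind can be passed directly to torch.utils.data.DataLoader, whose default collate function stacks the tuples into batched tensors. The sketch below uses torch.utils.data.TensorDataset as a stand-in that shares this tuple convention; the tensor values are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# TensorDataset stands in for KoopmanDataSet here: both return an
# (input, label) tuple from __getitem__.
x = torch.arange(10, dtype=torch.float32).reshape(5, 2)  # 5 samples, 2-D state
y = 2.0 * x                                             # illustrative labels
dataset = TensorDataset(x, y)

sample, label = dataset[3]  # single-index access returns a tuple
print(sample, label)        # tensor([6., 7.]) tensor([12., 14.])

# DataLoader's default collate stacks the tuples into batched tensors.
loader = DataLoader(dataset, batch_size=2, shuffle=False)
batches = list(loader)
print(len(batches))         # 3 (the last batch holds a single sample)
print(batches[0][0].shape)  # torch.Size([2, 2])
```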
__init__(dynamics, x_sample_func=torch.rand)
Initializes the KoopmanDataSet object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dynamics | DiscreteDynamics | An object representing the dynamics of the system. | required |
x_sample_func | callable | A function to sample initial states. | torch.rand |
Attributes:
Name | Type | Description |
---|---|---|
_dynamics | DiscreteDynamics | Stores the dynamics of the system. |
_generated | bool | Indicates whether the dataset has been generated. |
_x_sample_func | callable | Function to sample initial states. |
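Since x_sample_func defaults to torch.rand, a replacement presumably needs the same calling convention: a shape in, a tensor of values in [0, 1) out. This is an assumption; the contract is not documented above. As a sketch, the hypothetical wrapper sobol_sample below draws scrambled Sobol points through torch.quasirandom.SobolEngine and could then be passed as KoopmanDataSet(dynamics, x_sample_func=sobol_sample). It handles only two-dimensional (n_samples, state_dim) shapes.

```python
import torch
from torch.quasirandom import SobolEngine


def sobol_sample(*size, seed: int = 0) -> torch.Tensor:
    """Hypothetical drop-in for torch.rand(n, d): low-discrepancy points in [0, 1)^d."""
    # Accept both sobol_sample(n, d) and sobol_sample((n, d)), as torch.rand does.
    if len(size) == 1 and isinstance(size[0], (tuple, list, torch.Size)):
        size = tuple(size[0])
    n, d = size  # only 2-D shapes (n_samples, state_dim) are supported
    engine = SobolEngine(dimension=d, scramble=True, seed=seed)
    return engine.draw(n)


# Both calls follow the same convention and return tensors of the same shape.
u_rand = torch.rand(8, 2)
u_sobol = sobol_sample(8, 2)
print(u_rand.shape, u_sobol.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```

Sobol points cover the sampling box more evenly than i.i.d. uniform draws for small sample counts, which can matter when only a few initial states are generated.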
__len__()
Returns the number of samples in the dataset.
Returns:
Type | Description |
---|---|
int | The number of samples in the dataset. |
generate_data(n_traj, traj_len, x_min, x_max, param, seed_x=11)
Generates a dataset of trajectories for a dynamical system.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_traj | int | Number of trajectories to generate. | required |
traj_len | int | Length of each trajectory. | required |
x_min | float or Tensor | Minimum value(s) for initial state sampling. | required |
x_max | float or Tensor | Maximum value(s) for initial state sampling. | required |
param | Tensor | Parameters for the dynamics function. | required |
seed_x | int | Random seed for initial state sampling. | 11 |
Returns:
Type | Description |
---|---|
None | Nothing is returned; the generated data is stored in instance variables and accessed through the data_x and labels properties. |
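The steps generate_data describes (seeded sampling of initial states between x_min and x_max, then iterating the dynamics with param) can be sketched as below. Everything here is an assumption: generate_trajectories and linear_step are hypothetical names, the step(x, param) signature stands in for whatever DiscreteDynamics actually exposes, traj_len is taken to count states including the initial one, and labels are assumed to be the one-step successors of data_x, the snapshot pairing EDMD typically uses.

```python
import torch


def generate_trajectories(step, n_traj, traj_len, x_min, x_max, param,
                          seed_x=11, x_sample_func=torch.rand):
    """Hypothetical sketch of generate_data; returns (data_x, labels).

    Assumes step(x, param) maps a batch of states (n_traj, d) to the next
    states and that labels are the one-step successors of data_x.
    """
    x_min = torch.as_tensor(x_min, dtype=torch.float32)
    x_max = torch.as_tensor(x_max, dtype=torch.float32)
    dim = max(x_min.numel(), x_max.numel())

    torch.manual_seed(seed_x)  # reproducible initial states
    # Rescale samples from [0, 1) into [x_min, x_max), per component.
    x = x_min + (x_max - x_min) * x_sample_func(n_traj, dim)

    states = [x]
    for _ in range(traj_len - 1):
        x = step(x, param)
        states.append(x)
    traj = torch.stack(states, dim=1)  # shape (n_traj, traj_len, d)

    data_x = traj[:, :-1].reshape(-1, dim)  # x_k
    labels = traj[:, 1:].reshape(-1, dim)   # x_{k+1}
    return data_x, labels


def linear_step(x, param):
    # Hypothetical linear dynamics x_{k+1} = A x_k, applied row-wise with A = param.
    return x @ param.T


A = torch.tensor([[0.9, 0.1], [0.0, 0.8]])
data_x, labels = generate_trajectories(linear_step, n_traj=4, traj_len=5,
                                       x_min=-1.0, x_max=[1.0, 2.0], param=A)
print(data_x.shape, labels.shape)  # torch.Size([16, 2]) torch.Size([16, 2])
```

Seeding with torch.manual_seed(seed_x) right before calling the sampler means repeated calls with the same seed reproduce the same initial states, which is presumably why seed_x has a fixed default.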
load(file)
Load data from a pickle file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file | str | The path to the pickle file containing the data to be loaded. | required |
save(file)
Save the generated dataset to a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file | str | The path to the file where the dataset will be saved. | required |
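save and load together form a round trip through a pickle file. A minimal sketch, assuming the dataset is serialized as a plain dict of its generated tensors; the dict keys and the helper names save_dataset and load_dataset are hypothetical, and the real file layout may differ.

```python
import os
import pickle
import tempfile

import torch


def save_dataset(data_x, labels, file):
    """Hypothetical sketch of save(): pickle the generated tensors to `file`."""
    with open(file, "wb") as f:
        pickle.dump({"data_x": data_x, "labels": labels}, f)


def load_dataset(file):
    """Hypothetical sketch of load(): read back what save_dataset wrote."""
    # Only unpickle trusted files: pickle.load can execute arbitrary code.
    with open(file, "rb") as f:
        payload = pickle.load(f)
    return payload["data_x"], payload["labels"]


data_x = torch.rand(6, 2)
labels = torch.rand(6, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "koopman_data.pkl")
    save_dataset(data_x, labels, path)
    x_loaded, y_loaded = load_dataset(path)

print(torch.equal(x_loaded, data_x), torch.equal(y_loaded, labels))  # True True
```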