Class TrainableDictionary
The TrainableDictionary class is a subclass of the Dictionary class.
It contains a trainable neural network and a set of \(N_y\) non-trainable observable functions
(see ObservableFunction).
The dictionary can be represented by the mapping
\(\Psi: \mathbb{R}^{N \times N_x} \rightarrow \mathbb{R}^{N \times N_{\psi}}\), where \(N\) is the number of samples in a batch.
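As a rough illustration of this mapping, the sketch below builds \(\Psi\) for a batch by concatenating, column-wise, a constant column, the outputs of fixed observable functions, and the outputs of a small network. The ordering [1, observables, network outputs] and the helper name evaluate_dictionary are assumptions made for illustration; the library's actual layout may differ.

```python
import torch
import torch.nn as nn


def evaluate_dictionary(x, observable_func, network):
    """Hypothetical sketch of Psi: (N, N_x) -> (N, N_psi).

    Assumes Psi(x) = [1, g(x), NN(x)] concatenated along the feature axis.
    """
    ones = torch.ones(x.shape[0], 1, dtype=x.dtype)  # constant observable 1
    fixed = observable_func(x)                        # non-trainable observables, (N, N_y)
    learned = network(x)                              # trainable part
    return torch.cat([ones, fixed, learned], dim=1)


N, N_x = 5, 2
x = torch.randn(N, N_x)
network = nn.Sequential(nn.Linear(N_x, 16), nn.Tanh(), nn.Linear(16, 7))
psi = evaluate_dictionary(x, lambda t: t, network)  # identity observables, so N_y = N_x
print(psi.shape)  # torch.Size([5, 10]) since N_psi = 1 + 2 + 7
```

Under this assumed layout, \(N_\psi = 1 + N_y + (\text{network width})\).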
API Documentation
Bases: Dictionary
__init__(network, observable_func, dim_input, dim_output)
Create a trainable dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network | Module | The trainable network. | required |
observable_func | torch.Tensor -> torch.Tensor | Observable functions \((N, N_x) \rightarrow (N, N_{\psi})\) | required |
dim_input | int | Input dimension \(N_x\). | required |
dim_output | int | Output dimension \(N_\psi\). | required |
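The constructor expects an nn.Module and a callable mapping tensors to tensors, passed as TrainableDictionary(network, observable_func, dim_input, dim_output). The sketch below only prepares and sanity-checks such arguments on a dummy batch; it does not import TrainableDictionary because the package path is not given here. The specific network, the squared-state observables, and the bookkeeping dim_output = 1 + (observable width) + (network width) are all assumptions for illustration.

```python
import torch
import torch.nn as nn

dim_input = 2  # N_x

# Hypothetical trainable part: any nn.Module mapping (N, N_x) -> (N, k).
network = nn.Sequential(
    nn.Linear(dim_input, 32),
    nn.Tanh(),
    nn.Linear(32, 5),
)


def observable_func(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical fixed observables: the state itself and its squares."""
    return torch.cat([x, x ** 2], dim=1)


# Assumed bookkeeping: constant 1 + fixed observables + network outputs.
x = torch.randn(8, dim_input)
n_fixed = observable_func(x).shape[1]  # 4
n_learned = network(x).shape[1]        # 5
dim_output = 1 + n_fixed + n_learned
print(dim_output)  # 10
```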
Notes
The constant observable function \(\mathbf{1}\) is included in the TrainableDictionary
by default, so users don't need to define it explicitly.
eval()
Set the network to evaluation mode.
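Assuming eval() delegates to torch.nn.Module.eval() on the wrapped network (the description suggests this but does not state it), its main effect is to switch layers such as Dropout and BatchNorm to inference behavior. The sketch below shows this on a plain PyTorch network containing dropout; wrapping inference in torch.no_grad() is a common pattern, not something the class is documented to do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
network = nn.Sequential(nn.Linear(2, 64), nn.Dropout(p=0.5), nn.Linear(64, 3))
x = torch.randn(4, 2)

network.eval()         # dropout now acts as the identity
with torch.no_grad():  # no autograd graph is built for inference
    y1 = network(x)
    y2 = network(x)

print(network.training)     # False
print(torch.equal(y1, y2))  # True: evaluation mode is deterministic here
```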
load(path)
Load the network from a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to load the network from. | required |
parameters()
Return the parameters of the network.
Returns:
Type | Description |
---|---|
Iterator[Parameter] | The parameters of the network. |
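The iterator returned by parameters() is the kind of object torch.optim optimizers accept. Assuming it yields the same tensors as the wrapped network's own parameters(), the sketch below counts the trainable weights of a small stand-in network and hands them to Adam; the architecture and learning rate are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

network = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 3))

# parameters() returns a generator, so call it again (or wrap it in list())
# whenever it has to be consumed more than once.
n_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

print(n_params)  # 51 = (2*8 + 8) + (8*3 + 3)
print(len(optimizer.param_groups[0]["params"]))  # 4: two weight matrices, two bias vectors
```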
save(path)
Save the network to a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | The path to save the network to. | required |
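The storage format used by save() and load() is not specified here. One common PyTorch convention, assumed in the sketch below, is to store the network's state_dict with torch.save and restore it into a network of identical architecture with load_state_dict. The helper make_network and the file name are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_network() -> nn.Module:
    # With state_dict, the architecture must match between saving and loading.
    return nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 3))


torch.manual_seed(0)
original = make_network()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dictionary_net.pt")  # hypothetical file name
    torch.save(original.state_dict(), path)        # sketch of save(path)

    restored = make_network()                      # fresh, differently initialized weights
    restored.load_state_dict(torch.load(path))     # sketch of load(path)

x = torch.randn(4, 2)
with torch.no_grad():
    same = torch.equal(original(x), restored(x))
print(same)  # True
```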
train()
Set the network to training mode.
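A typical loop alternates the two modes: training mode for the optimization steps, evaluation mode for monitoring. The sketch below shows that pattern on a stand-in network fitting a synthetic linear map; it assumes train() and eval() simply toggle the wrapped module's mode, and the data, loss, and hyperparameters are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(128, 2)
y = x @ torch.tensor([[1.0, -2.0, 0.5], [0.3, 0.0, 1.5]])  # synthetic linear target

network = nn.Linear(2, 3)
optimizer = torch.optim.Adam(network.parameters(), lr=0.05)
loss_fn = nn.MSELoss()

history = []
for epoch in range(100):
    network.train()  # training mode for the optimization step
    optimizer.zero_grad()
    loss = loss_fn(network(x), y)
    loss.backward()
    optimizer.step()

    network.eval()  # evaluation mode for monitoring
    with torch.no_grad():
        history.append(loss_fn(network(x), y).item())

print(history[-1] < history[0])  # True: the fit improves
```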