
Class TrainableDictionary

The TrainableDictionary class is a subclass of the Dictionary class. It combines a trainable neural network with a set of \(N_y\) non-trainable observable functions (see ObservableFunction). The dictionary can be represented by the mapping \(\Psi: \mathbb{R}^{N \times N_x} \rightarrow \mathbb{R}^{N \times N_{\psi}}\), which takes a batch of \(N\) states of dimension \(N_x\) to \(N_\psi\) dictionary features per state.
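As a rough illustration of the mapping \(\Psi\), the sketch below assumes that each output row is the concatenation of the constant \(\mathbf{1}\), the \(N_y\) observable outputs, and the network outputs, in that order. The column order and the helper name `dictionary_forward` are assumptions made for illustration; this is not the library's implementation.

```python
import torch
from torch import nn


def dictionary_forward(x, observable_func, network):
    """Hypothetical sketch of Psi: (N, N_x) -> (N, N_psi).

    Assumes Psi(x) = [1, g(x), NN(x)] stacked column-wise, so that
    N_psi = 1 + N_y + (number of network outputs).
    """
    # Constant observable: one column of ones per sample.
    ones = torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)
    # Fixed (non-trainable) observables, e.g. the state itself.
    fixed = observable_func(x)
    # Trainable features produced by the network.
    learned = network(x)
    return torch.cat([ones, fixed, learned], dim=1)


torch.manual_seed(0)
x = torch.randn(5, 2)        # N = 5 samples, N_x = 2
network = nn.Linear(2, 3)    # 3 trainable features
psi = dictionary_forward(x, lambda v: v, network)  # g(x) = x, so N_y = 2
print(psi.shape)  # torch.Size([5, 6])
```

Keeping the observable functions outside the network means their outputs carry no trainable weights; only the last block of columns changes during training.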

API Documentation

Bases: Dictionary

__init__(network, observable_func, dim_input, dim_output)

Construct a trainable dictionary.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| network | Module | The trainable network. | required |
| observable_func | torch.Tensor -> torch.Tensor | Observable functions \((N, N_x) \rightarrow (N, N_{\psi})\). | required |
| dim_input | int | Input dimension \(N_x\). | required |
| dim_output | int | Output dimension \(N_\psi\). | required |

Notes

The constant observable function \(\mathbf{1}\) is included in the TrainableDictionary by default, so users don't need to define it explicitly.
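Because the constant \(\mathbf{1}\) is supplied automatically, fewer than \(N_\psi\) columns remain for the user-supplied parts. The sketch below assumes that `dim_output` counts the constant, the \(N_y\) fixed observables, and the network outputs, so the network's final layer needs \(N_\psi - 1 - N_y\) units. That accounting and the helper `network_output_width` are hypothetical; the architecture is an arbitrary example.

```python
import torch
from torch import nn


def network_output_width(dim_output, n_observables):
    """Hypothetical helper: columns left for the network once the constant
    observable (1 column) and the N_y fixed observables are accounted for."""
    width = dim_output - 1 - n_observables
    if width <= 0:
        raise ValueError("dim_output is too small for the requested observables")
    return width


def identity_observables(x):
    # Fixed observables g(x) = x, so N_y equals N_x.
    return x


dim_input, dim_output = 2, 25  # N_x = 2, N_psi = 25
width = network_output_width(dim_output, n_observables=dim_input)  # 25 - 1 - 2 = 22

network = nn.Sequential(
    nn.Linear(dim_input, 64),
    nn.Tanh(),
    nn.Linear(64, width),
)

# These arguments would then be passed to the documented constructor:
# TrainableDictionary(network, identity_observables, dim_input, dim_output)
x = torch.randn(8, dim_input)
print(network(x).shape)  # torch.Size([8, 22])
```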

eval()

Set the network to evaluation mode.
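Evaluation mode matters when the network contains layers such as dropout or batch normalization, whose behavior differs between training and inference. The sketch assumes that `eval()` forwards to `torch.nn.Module.eval()` on the wrapped network; the dropout network is a stand-in. Pairing it with `torch.no_grad()` additionally skips gradient tracking during inference.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Stand-in for the wrapped network; dropout makes the mode visible.
network = nn.Sequential(nn.Linear(2, 16), nn.Dropout(p=0.5), nn.Linear(16, 4))
x = torch.randn(10, 2)

network.eval()  # what TrainableDictionary.eval() is assumed to delegate to
with torch.no_grad():  # inference only: no autograd graph is built
    out1 = network(x)
    out2 = network(x)

print(network.training)         # False
print(torch.equal(out1, out2))  # True: dropout is disabled in eval mode
print(out1.requires_grad)       # False
```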

load(path)

Load the network from a file.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str | The path to load the network from. | required |


parameters()

Return the parameters of the network.

Returns:

| Type | Description |
|---|---|
| Iterator[Parameter] | The parameters of the network. |
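A typical use of `parameters()` is to hand the trainable weights to an optimizer. Since the observable functions are non-trainable, only the network's weights are expected to appear. The sketch uses a stand-in network directly with `torch.optim.Adam` and an arbitrary learning rate; with the documented API the same call would receive `dictionary.parameters()`, where `dictionary` is a hypothetical TrainableDictionary instance.

```python
import torch
from torch import nn

network = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 22))

# parameters() yields torch.nn.Parameter objects; the optimizer stores them.
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# Call parameters() again here: the iterator passed above is already consumed.
n_trainable = sum(p.numel() for p in network.parameters() if p.requires_grad)
print(n_trainable)  # (2*64 + 64) + (64*22 + 22) = 192 + 1430 = 1622
```

Because the return type is an iterator, it can be traversed only once; materialize it with `list(...)` if it has to be reused.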

save(path)

Save the network to a file.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str | The path to save the network to. | required |
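The file format used by `save` and `load` is not specified here. A common PyTorch convention is to store the network's `state_dict` with `torch.save` and restore it with `load_state_dict`; the sketch below assumes that convention, writes to a temporary file, and checks the round trip. The helper `make_network` and the file name are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


def make_network():
    # With a state_dict, the architecture must match on save and on load.
    return nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 4))


torch.manual_seed(0)
source = make_network()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dictionary_net.pt")
    torch.save(source.state_dict(), path)  # assumed equivalent of save(path)

    target = make_network()  # freshly initialised, different weights
    target.load_state_dict(torch.load(path, map_location="cpu"))  # assumed load(path)

x = torch.randn(3, 2)
with torch.no_grad():
    print(torch.equal(source(x), target(x)))  # True
```

Saving only the weights (rather than the whole module object) keeps the file independent of the class definition's import path, at the cost of having to rebuild the same architecture before loading.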

train()

Set the network to training mode.
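Mode switching usually brackets the optimization loop: training mode before gradient steps, evaluation mode before inference. The sketch shows one such step on a stand-in network and assumes `train()` forwards to `torch.nn.Module.train()`. The loss is a placeholder mean squared error against random targets, not the objective the library trains with.

```python
import torch
from torch import nn

torch.manual_seed(0)
network = nn.Sequential(nn.Linear(2, 16), nn.Dropout(p=0.1), nn.Linear(16, 4))
optimizer = torch.optim.SGD(network.parameters(), lr=1e-2)
x, target = torch.randn(32, 2), torch.randn(32, 4)
before = [p.detach().clone() for p in network.parameters()]

network.train()  # assumed equivalent of TrainableDictionary.train()
was_training = network.training
optimizer.zero_grad()
loss = nn.functional.mse_loss(network(x), target)  # placeholder loss
loss.backward()
optimizer.step()

network.eval()  # switch back before evaluating
print(was_training, network.training)  # True False
```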