archetypes.torch.AA
- class archetypes.torch.AA(k, m, n, device='cpu')
Archetype analysis implemented in PyTorch.
- Parameters:
- k: int
The number of archetypes to use.
- m: int
The number of observations.
- n: int
The number of variables.
- device: str
The device to use for training the model. Defaults to “cpu”.
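- Attributes:
- A: torch.Tensor
The A coefficient matrix.
- B: torch.Tensor
The B coefficient matrix.
- Z: torch.Tensor
The archetypes matrix.
Methods
- train(data, n_epochs[, learning_rate])
Train the model.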
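For orientation, archetype analysis is usually stated as the constrained least-squares problem below. The names A, B and Z match this class's properties. The shapes and simplex constraints are the standard convention and are assumed here; they are not confirmed for this class.

```latex
\begin{aligned}
&\min_{A,\,B}\ \lVert X - A Z \rVert_F^2, \qquad Z = B X,\\
&X \in \mathbb{R}^{m \times n},\quad A \in \mathbb{R}^{m \times k},\quad B \in \mathbb{R}^{k \times m},\quad Z \in \mathbb{R}^{k \times n},\\
&A \ge 0,\quad A\,\mathbf{1}_k = \mathbf{1}_m, \qquad B \ge 0,\quad B\,\mathbf{1}_m = \mathbf{1}_k.
\end{aligned}
```

Under this convention, each row of A expresses one observation as a convex combination of the k archetypes, and each row of B expresses one archetype as a convex combination of the m observations.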
- property A
The A coefficient matrix.
- Returns:
- torch.Tensor
- property B
The B coefficient matrix.
- Returns:
- torch.Tensor
- property Z
The archetypes matrix.
- Returns:
- torch.Tensor
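The following is a minimal sketch of how A, B and Z typically fit together. It assumes the standard formulation (A is m×k, B is k×m, Z = B X) and uses a row-wise softmax to keep every row of A and B non-negative and summing to 1. The softmax parameterization and the helper `coefficients_from_logits` are illustrative assumptions, not taken from this class.

```python
import torch

def coefficients_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map unconstrained logits to rows on the simplex."""
    return torch.softmax(logits, dim=1)

torch.manual_seed(0)
m, n, k = 6, 2, 3                 # observations, variables, archetypes
X = torch.rand(m, n)              # data matrix, one row per observation

# Assumed shapes: A is (m, k), B is (k, m); both are row-stochastic.
A = coefficients_from_logits(torch.randn(m, k))
B = coefficients_from_logits(torch.randn(k, m))

Z = B @ X                         # archetypes (k, n): convex mixes of observations
X_hat = A @ Z                     # reconstruction (m, n): convex mixes of archetypes
rss = torch.sum((X - X_hat) ** 2) # residual sum of squares minimized by training

print(Z.shape, X_hat.shape, float(rss))
```

Because each archetype is a convex combination of observations, it always lies inside the convex hull of the data. This is what makes archetypes interpretable as "extreme" but realistic points.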
- train(data, n_epochs, learning_rate=0.01)
Train the model.
- Parameters:
- data: torch.Tensor
The data to be used for training.
- n_epochs: int
The number of epochs to train the model for.
- learning_rate: float
The learning rate to use for training. Defaults to 0.01.
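A training loop with a similar signature could look roughly like the sketch below. It is not this class's implementation. It assumes the softmax parameterization of A and B, a residual-sum-of-squares loss, the Adam optimizer, and that one "epoch" is one full-batch gradient step over all m observations. `train_aa` is a hypothetical stand-alone function.

```python
import torch

def train_aa(data: torch.Tensor, k: int, n_epochs: int,
             learning_rate: float = 0.01, device: str = "cpu", seed: int = 0):
    """Hypothetical full-batch archetype analysis trainer (illustrative only)."""
    torch.manual_seed(seed)
    X = data.to(device)
    m, _ = X.shape
    # Unconstrained parameters; the softmax below keeps each row on the simplex.
    A_logits = torch.randn(m, k, device=device, requires_grad=True)
    B_logits = torch.randn(k, m, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([A_logits, B_logits], lr=learning_rate)

    losses = []
    for _ in range(n_epochs):
        A = torch.softmax(A_logits, dim=1)   # (m, k)
        B = torch.softmax(B_logits, dim=1)   # (k, m)
        Z = B @ X                            # archetypes, (k, n)
        loss = torch.sum((X - A @ Z) ** 2)   # residual sum of squares
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())           # loss before this epoch's update

    # Final coefficients and archetypes, detached from the autograd graph.
    with torch.no_grad():
        A = torch.softmax(A_logits, dim=1)
        B = torch.softmax(B_logits, dim=1)
        Z = B @ X
    return A, B, Z, losses

# Usage: three archetypes for 200 random points in the unit square.
torch.manual_seed(0)
data = torch.rand(200, 2)
A, B, Z, losses = train_aa(data, k=3, n_epochs=300, learning_rate=0.05)
print(Z)
```

The softmax reparameterization turns the simplex constraints into an unconstrained problem that any PyTorch optimizer can handle. Projected gradient steps are a common alternative.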