archetypes.torch.AA

class archetypes.torch.AA(k, m, n, device='cpu')

Archetypal analysis (AA) implemented in PyTorch.
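For orientation, the standard archetypal analysis problem is sketched below using the names A, B and Z that this class exposes. It is an assumption, not stated on this page, that the class follows exactly this convention. Here X is the m × n data matrix, with one row per observation and one column per variable, and k is the number of archetypes:

```latex
\begin{aligned}
\min_{A,\,B}\;\; & \lVert X - A Z \rVert_F^2, \qquad Z = B X \in \mathbb{R}^{k \times n}, \\
\text{s.t.}\;\; & A \in \mathbb{R}^{m \times k},\; A \ge 0,\; A \mathbf{1}_k = \mathbf{1}_m, \\
& B \in \mathbb{R}^{k \times m},\; B \ge 0,\; B \mathbf{1}_m = \mathbf{1}_k .
\end{aligned}
```

Each archetype (row of Z) is a convex combination of observations. Each observation is approximated by a convex combination of the k archetypes.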

Parameters:
k: int

The number of archetypes to use.

m: int

The number of observations.

n: int

The number of variables.

device: str

The device on which to train the model, for example 'cpu' or 'cuda'. Defaults to 'cpu'.
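The constructor receives m and n separately from the data, so it can help to check that the tensor later passed to train has shape (m, n) and sits on the chosen device. The sketch below uses a hypothetical helper, prepare_data, which is not part of this API. It also assumes, without confirmation from this page, that train expects its data on the same device as the model:

```python
import torch


def prepare_data(data, m: int, n: int, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: convert array-like data to an (m, n) float tensor on `device`.

    Assumes one row per observation and one column per variable, matching the
    m and n passed to the constructor.
    """
    x = torch.as_tensor(data, dtype=torch.float32, device=device)
    if x.shape != (m, n):
        raise ValueError(f"expected shape ({m}, {n}), got {tuple(x.shape)}")
    return x


# Use a GPU when one is available, otherwise fall back to the default "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
x = prepare_data([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], m=3, n=2, device=device)
```

The same device string would then be passed to the constructor, so that the model parameters and the data live on one device.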

Attributes:
A

The coefficient matrix A.

B

The coefficient matrix B.

Z

The archetype matrix Z.
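In the standard formulation, the three attributes relate as follows. The rows of A and B are convex weights, Z = B X, and A Z approximates the data. A common way to enforce the simplex constraints in PyTorch is a row-wise softmax over unconstrained parameters. Whether this class does exactly that is an assumption, so treat the sketch below as an illustration of the shapes and constraints only. The *_logits names are hypothetical:

```python
import torch

torch.manual_seed(0)
m, n, k = 100, 2, 3  # observations, variables, archetypes

X = torch.randn(m, n)  # illustrative data, one row per observation

# Unconstrained parameters (hypothetical names). A row-wise softmax maps each
# row onto the probability simplex: non-negative entries that sum to 1.
A_logits = torch.randn(m, k)
B_logits = torch.randn(k, m)
A = torch.softmax(A_logits, dim=1)  # (m, k): weights of each observation over archetypes
B = torch.softmax(B_logits, dim=1)  # (k, m): weights of each archetype over observations
Z = B @ X                           # (k, n): archetypes as convex combinations of observations

X_hat = A @ Z                       # (m, n): reconstruction of the data
```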

Methods

train(data, n_epochs[, learning_rate])

Train the model.

property A

The coefficient matrix A.

Returns:
torch.Tensor
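Assume A has shape (m, k), with one row of archetype weights per observation. A simple way to summarize it is then to take, for each observation, the archetype with the largest weight. The helper below is hypothetical and only illustrates one way to read the returned tensor:

```python
import torch


def dominant_archetype(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the largest-weight archetype in each row of A.

    Assumes A has shape (m, k) with one row of convex weights per observation.
    """
    return A.argmax(dim=1)


# A small hand-written A for three observations and two archetypes.
A = torch.tensor([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.6, 0.4]])
labels = dominant_archetype(A)  # tensor([0, 1, 0])
```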

property B

The coefficient matrix B.

Returns:
torch.Tensor
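Under the standard formulation, each row of B weights the m observations and sums to 1. Each archetype in Z = B X is therefore a convex combination of data points. One sanity check follows from this: every archetype falls within the per-variable range of the data. The sketch uses randomly generated stand-in values, not output of this class:

```python
import torch

torch.manual_seed(0)
m, n, k = 50, 3, 4

X = torch.rand(m, n) * 10.0                  # illustrative data
B = torch.softmax(torch.randn(k, m), dim=1)  # (k, m): rows are convex weights
Z = B @ X                                    # (k, n): archetypes

# A convex combination of rows of X cannot leave the per-variable range of X,
# so every archetype lies inside the bounding box of the data (and, more
# strictly, inside its convex hull).
lo, hi = X.min(dim=0).values, X.max(dim=0).values
inside = bool(((Z >= lo) & (Z <= hi)).all())
```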

property Z

The archetype matrix Z.

Returns:
torch.Tensor
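Z is combined with A to reconstruct the data as A Z. A natural measure of fit is the relative Frobenius error. This assumes the shapes of the standard formulation: X of shape (m, n), A of shape (m, k) and Z of shape (k, n). The helper below is a hypothetical illustration, not a method of this class:

```python
import torch


def relative_reconstruction_error(X: torch.Tensor, A: torch.Tensor, Z: torch.Tensor) -> float:
    """Hypothetical helper: ||X - A Z||_F / ||X||_F.

    Assumes X is (m, n), A is (m, k) and Z is (k, n). torch.linalg.norm on a
    matrix with default arguments returns its Frobenius norm.
    """
    return (torch.linalg.norm(X - A @ Z) / torch.linalg.norm(X)).item()


# If the archetypes are the data points themselves and A selects each one
# exactly, the reconstruction is perfect.
X = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
A = torch.eye(3)
Z = X.clone()
err = relative_reconstruction_error(X, A, Z)  # 0.0
```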

train(data, n_epochs, learning_rate=0.01)

Train the model.

Parameters:
data: torch.Tensor

The training data, with one row per observation and one column per variable, i.e. a tensor of shape (m, n).

n_epochs: int

The number of epochs to train the model for.

learning_rate: float

The learning rate to use for training. Defaults to 0.01.
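The parameters of train correspond to a gradient-descent loop that runs for n_epochs iterations with step size learning_rate. The sketch below writes such a loop out explicitly. Several parts of it are assumptions made for illustration, not a description of what train does internally: the softmax parameterization, the Adam optimizer, the Frobenius-norm loss, treating one epoch as one full-batch step, and the hypothetical name fit_archetypes.

```python
import torch


def fit_archetypes(X: torch.Tensor, k: int, n_epochs: int, learning_rate: float = 0.01, seed: int = 0):
    """Hypothetical stand-in for a train-like loop (not this library's code).

    Parameterizes A (m, k) and B (k, m) through a row-wise softmax so their
    rows stay on the simplex, then minimizes ||X - A B X||_F with Adam.
    Returns (A, B, Z, losses).
    """
    gen = torch.Generator().manual_seed(seed)
    m = X.shape[0]
    A_logits = torch.randn(m, k, generator=gen).requires_grad_()
    B_logits = torch.randn(k, m, generator=gen).requires_grad_()
    optimizer = torch.optim.Adam([A_logits, B_logits], lr=learning_rate)

    losses = []
    for _ in range(n_epochs):  # one full-batch step per epoch in this sketch
        optimizer.zero_grad()
        A = torch.softmax(A_logits, dim=1)
        B = torch.softmax(B_logits, dim=1)
        loss = torch.linalg.norm(X - A @ (B @ X))  # Frobenius norm of the residual
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Final constrained matrices, without tracking gradients.
    with torch.no_grad():
        A = torch.softmax(A_logits, dim=1)
        B = torch.softmax(B_logits, dim=1)
        Z = B @ X
    return A, B, Z, losses


# Points spread over a triangle; with k=3 the archetypes tend to move toward its corners.
torch.manual_seed(1)
corners = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
weights = torch.distributions.Dirichlet(torch.ones(3)).sample((200,))  # (200, 3)
X = weights @ corners                                                   # (200, 2)
A, B, Z, losses = fit_archetypes(X, k=3, n_epochs=300, learning_rate=0.05)
```

With the documented class, the comparable call would be AA(3, 200, 2) followed by train(X, 300, learning_rate=0.05). After that, the fitted matrices are read from the A, B and Z properties. Whether its optimizer and loss match this sketch is not stated on this page.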