Basic Usage
This section provides a minimal example of training a toy model of superposition using the tms_kit
library.
Minimal Example
Setup
First, we’ll import the dependencies.
import torch
from tms_kit.loss import ImportanceWeightedLoss
from tms_kit.model import Model
from tms_kit.data import IIDFeatureGenerator
from tms_kit.optimize import optimize
from tms_kit.tms import TMS
from tms_kit.utils import utils
from tms_kit.utils.device import set_device
Defining hyperparameters
Next, we’ll define the hyperparameters of the experiment.
# Define the configuration
set_device('cpu')
n_inst = 10     # number of toy model instances trained in parallel
n_features = 5  # number of input features
d_hidden = 2    # dimensionality of the hidden layer
feature_probability = 0.01 * torch.ones(n_inst, n_features)  # per-feature presence probability
feature_importance = 1.0 * torch.ones(n_inst, n_features)    # per-feature loss weight
The hyperparameters are as follows:

- n_inst: the number of separate toy model instances that we train. Our code facilitates training multiple instances in parallel to speed up experimentation.
- n_features: the number of features in the input data. This is expected to be constant across all instances.
- d_hidden: the dimensionality of the toy model’s hidden layer. Again, this is expected to be constant across all instances.
- feature_probability: the probability of each feature being present in the input data. Here, we’re using a uniform probability of 0.01 for each feature.
- feature_importance: the importance of each feature in the loss function. Here, we’re using a uniform importance of 1.0 for each feature.
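Both tensors are per-instance and per-feature, so they don’t have to be uniform. As a sketch (not part of this demo), you could sweep sparsity across instances and give features geometrically decaying importance:

# Example variation (not used in this demo): sweep feature sparsity across
# instances and give later features geometrically decaying importance.
feature_probability_swept = (
    0.5 ** torch.arange(n_inst)            # one sparsity level per instance
).unsqueeze(1).expand(n_inst, n_features)  # broadcast over features
feature_importance_decay = (
    0.9 ** torch.arange(n_features)        # feature j has importance 0.9**j
).unsqueeze(0).expand(n_inst, n_features)  # same schedule for every instance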
Defining the TMS experiment components
Next, we’ll define a TMS subclass which establishes the overall settings for the TMS experiment.
# Define a TMS subclass with all the necessary components
class BottleneckTMS(TMS):
    def __init__(self):
        self.model = Model(n_features=n_features, n_inst=n_inst, d_hidden=d_hidden)
        self.loss_calc = ImportanceWeightedLoss(n_features=n_features, feature_importance=feature_importance)
        self.data_gen = IIDFeatureGenerator(n_features=n_features, n_inst=n_inst, feature_probability=feature_probability)
Each TMS instance bundles three components:

- A Model that represents the toy model of superposition being trained
- A DataGenerator that generates batches of data for training (a sketch of a typical sampling scheme follows this list)
- A LossCalculator that computes the loss for the model
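For intuition, IID sparse features of this kind are commonly sampled by drawing uniform values and independently zeroing each feature with probability 1 − feature_probability. The standalone sketch below illustrates that sampling scheme; it is not necessarily tms_kit’s exact implementation:

# Illustrative sketch of IID sparse feature generation (not necessarily the
# library's code): each feature takes a uniform value in [0, 1) and is present
# independently with its configured probability, otherwise it is zero.
def sample_iid_features(batch_size, feature_probability):
    n_inst, n_features = feature_probability.shape
    values = torch.rand(batch_size, n_inst, n_features)
    present = torch.rand(batch_size, n_inst, n_features) < feature_probability
    return values * present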
Here, we’re using pre-defined implementations for the model, data generator, and loss calculator. Each can be swapped for a different implementation to change one aspect of the experimental setting without touching the others.
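As a concrete example of what a loss calculator computes, an importance-weighted loss typically takes the form of a weighted squared reconstruction error. The sketch below shows that general form only; consult ImportanceWeightedLoss for the library’s actual definition:

# Sketch of an importance-weighted reconstruction loss (general form only;
# see ImportanceWeightedLoss for the library's actual definition).
def importance_weighted_mse(batch, reconstruction, feature_importance):
    # batch, reconstruction: (batch_size, n_inst, n_features)
    # feature_importance:    (n_inst, n_features), broadcast over the batch
    squared_error = (batch - reconstruction) ** 2
    return (feature_importance * squared_error).mean()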
Training the TMS
Finally, training the TMS is as simple as creating an instance of the subclass and calling the optimize function.
# Train a TMS
tms = BottleneckTMS()
optimize(tms)
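Under the hood, optimize runs a gradient-descent training loop over generated batches. The sketch below conveys only the general shape of such a loop; the method names, call signatures, and hyperparameters shown are hypothetical, not tms_kit’s actual API:

# Hypothetical sketch of the kind of loop optimize() runs; the method names,
# call signatures, and hyperparameters here are illustrative, not tms_kit's API.
opt = torch.optim.Adam(tms.model.parameters(), lr=1e-3)
for step in range(10_000):
    batch = tms.data_gen.generate_batch(batch_size=1024)  # hypothetical method name
    reconstruction = tms.model(batch)                     # hypothetical call signature
    loss = tms.loss_calc(batch, reconstruction)           # hypothetical call signature
    opt.zero_grad()
    loss.backward()
    opt.step()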
Inspecting the TMS features
Lastly, we can inspect the results of the training by plotting the features learned by the model in 2D space.
# Inspect a TMS
fig, ax = utils.plot_features_in_2d(
tms.model.W,
colors=feature_importance,
title=f"Superposition: {n_features} features represented in 2D space",
subplot_titles=[f"1 - S = {i:.3f}" for i in feature_probability[:, 0]],
)
utils.save_figure(fig, "5_2_superposition.png")
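To examine the learned weights directly, you can also plot one instance’s feature directions by hand. The sketch below assumes tms.model.W has shape (n_inst, d_hidden, n_features); the actual layout may differ, so check the Model definition:

# Manual inspection sketch; assumes W has shape (n_inst, d_hidden, n_features),
# which may differ from the actual Model definition.
import matplotlib.pyplot as plt

W0 = tms.model.W[0].detach().cpu()  # first instance: (d_hidden, n_features)
fig2, ax2 = plt.subplots()
ax2.scatter(W0[0], W0[1])
for j in range(W0.shape[1]):
    ax2.annotate(str(j), (W0[0, j].item(), W0[1, j].item()))
ax2.set_title("Instance 0: learned feature directions")
fig2.savefig("instance0_features.png")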
The resulting plot shows the learned feature directions for each instance in 2D space (figure: “Superposition: 5 features represented in 2D space”).
The full code for this example can be found in the experiments/demo directory of the repository.