Filtration Learning
import torch
import torch_geometric.transforms as T
import multipers as mp
from torch_geometric.datasets import TUDataset
from os.path import expanduser
from torch_geometric.data import Data
import torch.nn as nn
import numpy as np
import multipers.ml.signed_measures as mms
import multipers.grids as mpg
torch.manual_seed(1)
## TODO: fix these warnings at their source instead of silencing them
import warnings
warnings.filterwarnings("ignore")
This code is not meant to achieve state-of-the-art graph classification; it illustrates how to use multipers in a deep learning setting. The pipeline: a GCN produces a few filtration values per node, which induce a multiparameter filtration on each graph; a topological layer converts each filtered graph into signed measures (differentiably with respect to the filtration values); a convolution layer turns the measures into vectors; and a small classifier predicts the labels.
Dataset
dataset_name = "MUTAG"
def get_max_degree(dataset_name):
    from torch_geometric.utils import degree
    # NB: computed on the cleaned dataset; make sure this matches the variant loaded below
    dataset = TUDataset(expanduser("~/Datasets/torch_geometric/"), dataset_name, use_node_attr=True, cleaned=True)
    num_nodes = dataset.edge_index.max() + 1  # only the degrees matter here
    assert not Data(edge_index=dataset.edge_index, num_nodes=num_nodes).is_directed()
    a = degree(index=dataset.edge_index[0])
    b = degree(index=dataset.edge_index[1])
    assert (a == b).all()  # in-degree equals out-degree, as the graphs are undirected
    max_degree = a.max()
    return int(max_degree)
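This helper is only needed if you enable the OneHotDegree variant of the transform below:
# max_deg = get_max_degree(dataset_name)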
transform = T.Compose([
    # T.GDC(diffusion_kwargs={
    #     "method": "heat",
    #     "t": 10,
    # }),
    T.Constant(1),  # appends a constant 1 feature to each node
    T.LocalDegreeProfile(),
    # T.OneHotDegree(max_degree=get_max_degree(dataset_name)),  # degree before removing edges
    T.RemoveDuplicatedEdges(),
])
dataset = TUDataset(expanduser("~/Datasets/torch_geometric/"), dataset_name, use_node_attr=True, cleaned=False, transform=transform)
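A quick sanity check: MUTAG has 7 raw node attributes (see the shapes printed below), and the transform appends 1 constant feature plus the 5 LocalDegreeProfile features, so each transformed graph should carry 13 node features.
assert dataset[0].x.shape[-1] == 7 + 1 + 5  # raw attributes + constant + local degree profile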
from torch_geometric.loader import DataLoader
batch_size = len(dataset)
# batch_size = 100
shuffled_dataset = dataset.shuffle()
dataset_size = len(dataset)
split = int(0.9 * dataset_size)
# take both splits from the same shuffled dataset, otherwise train and test overlap
train_dataset, test_dataset = shuffled_dataset[:split], shuffled_dataset[split:]
train_dataset.x.shape, test_dataset.x.shape, dataset.x.shape  # NB: .x reads the raw storage, so the transform is not applied here (7 raw features)
(torch.Size([3027, 7]), torch.Size([336, 7]), torch.Size([3371, 7]))
train = next(iter(DataLoader(dataset, batch_size=len(dataset))))  # one Batch holding the whole dataset, used for the smoke tests below
Some GCN
from torch_geometric.nn.models import GCN
out_channels1 = 3  # number of filtration parameters for multiparameter persistence; keep it low!
first_gcn = GCN(in_channels=train.x.shape[-1], hidden_channels=50, num_layers=5, out_channels=out_channels1)
## test
out1 = first_gcn.forward(train.x, train.edge_index, batch = train.batch)
out1.shape, out1.dtype
(torch.Size([3371, 3]), torch.float32)
Some topological layers
Torch Graphs to Signed Measures
from torch_geometric.utils import unbatch, unbatch_edge_index
class Graph2SMLayer(torch.nn.Module):
    def __init__(
        self,
        degrees=(0, 1),
        grid_strategy: str = "exact",
        resolution: int = -1,  # -1 raises an error if the grid strategy requires a resolution
        n_jobs=-1,  # parallelizes the signed measure computations
        normalize: bool = False,
    ):
        super().__init__()
        self.normalize = normalize
        self.degrees = degrees
        self.grid_strategy = grid_strategy
        self.resolution = resolution
        self.n_jobs = n_jobs
    @torch.no_grad
    def _simplextree_transform(
        self,
        nodes_indices,
        nodes_filtrations,
        edge_indices,
        diff_grid,
    ):
        """
        Given a graph (node indices, edge indices) and its node filtration values,
        build the associated multiparameter simplex tree and return its signed measures.
        """
        num_parameters = nodes_filtrations.size(1)
        numpy_node_filtrations = nodes_filtrations.detach().numpy()
        st = mp.SimplexTreeMulti(num_parameters=num_parameters)
        nodes = nodes_indices[None, :].detach().numpy()
        st.insert_batch(nodes, numpy_node_filtrations)
        edges = edge_indices.detach().numpy()
        numpy_edges_filtrations = np.empty((0, 0), dtype=st.dtype)
        st.insert_batch(edges, numpy_edges_filtrations)  # empty filtrations -> -inf
        st = st.grid_squeeze(diff_grid, coordinate_values=True)
        if num_parameters == 2:
            st.collapse_edges(-1)  # edge collapses are only available with 2 parameters
        sms = mp.signed_measure(st, degrees=self.degrees, coordinate_measure=True)
        return sms
def _get_diff_grids(self, node_filtration_iterable):
from multipers.torch.diff_grids import get_grid
todo = get_grid(self.grid_strategy)
return tuple(todo(x.T, self.resolution) for x in node_filtration_iterable)
@torch.no_grad
def data2coordinate_sms(
self, node_indices, nodes_filtrations, edges_indices, diff_grids
):
from joblib import Parallel, delayed
sms = Parallel(n_jobs=self.n_jobs, backend="threading")(
delayed(self._simplextree_transform)(
node_index,
nodes_filtration,
edge_index,
diff_grid,
)
for node_index, nodes_filtration, edge_index, diff_grid in zip(
node_indices, nodes_filtrations, edges_indices, diff_grids
)
)
return sms
    def forward(self, nodes_filtrations, edges_indices, batch_indices=None):
        if batch_indices is None:
            # single graph: wrap everything in 1-element tuples
            node_indices = (torch.arange(nodes_filtrations.shape[0]),)
            nodes_filtrations = (nodes_filtrations,)
            edges_indices = (edges_indices,)
        else:
            node_indices = unbatch(
                torch.arange(nodes_filtrations.shape[0]), batch=batch_indices
            )
            nodes_filtrations = unbatch(nodes_filtrations, batch=batch_indices)
            edges_indices = unbatch_edge_index(edges_indices, batch=batch_indices)
        grids = self._get_diff_grids(nodes_filtrations)
        with torch.no_grad():
            sms = self.data2coordinate_sms(
                node_indices,
                nodes_filtrations,
                edges_indices,
                diff_grids=grids,
            )
        # The measures are computed without gradients (joblib and autograd do not mix);
        # re-indexing the coordinate measures in the differentiable grids restores gradient flow.
        sms = tuple(
            mpg.sms_in_grid(sm, diff_grid) for sm, diff_grid in zip(sms, grids)
        )
        sms = mms.SignedMeasureFormatter(
            unrag=True, deep_format=True, normalize=self.normalize
        ).fit_transform(sms)
        return sms
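Before wiring this layer into the pipeline, here is a minimal single-graph sketch (made-up filtration values on a 3-node path; it exercises the batch_indices=None branch above):
_toy_x = torch.rand(3, 2)  # 2 filtration parameters per node
_toy_edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected path 0-1-2
_toy_sms = Graph2SMLayer(degrees=[0], n_jobs=1).forward(_toy_x, _toy_edges, batch_indices=None)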
topological_layer = Graph2SMLayer(normalize=True, degrees=[0, 1], n_jobs=1)
# test
sms = topological_layer.forward(out1, train.edge_index, train.batch)
sms.dtype
[KeOps] Warning : Cuda libraries were not detected on the system or could not be loaded ; using cpu only mode
torch.float32
Vectorization Layer
class SMConvLayer(torch.nn.Module):
    def __init__(
        self,
        num_parameters: int,
        num_axis: int,
        dtype=torch.float64,
        num_convolutions: int | None = None,
        resolution: int = 5,
        out_channels: int | None = None,
    ):
        super().__init__()
        self.dtype = dtype
        self.num_parameters = num_parameters
        self.resolution = resolution
        self.num_convolutions = (
            num_parameters if num_convolutions is None else num_convolutions
        )
        biases = torch.stack(
            [
                10 * torch.diag(torch.rand(self.num_parameters, dtype=dtype))
                for _ in range(self.num_convolutions)
            ],
            dim=0,
        )
        # One matrix R per convolution; (R^T R)^{-1} below is the kernel's inverse covariance.
        self.Rs = nn.Parameter(
            torch.randn(
                # torch.rand(
                # torch.zeros(
                (self.num_convolutions, num_parameters, num_parameters),
                dtype=dtype,
            )
            + biases  # diagonal bias; maybe add a multiplicative factor?
        )
        # Evaluation points, initialized on a regular grid in [0, 1]^num_parameters.
        self.pts_to_evaluate = nn.Parameter(
            torch.stack(
                [
                    torch.cartesian_prod(
                        *(torch.linspace(0, 1, resolution) for _ in range(num_parameters))
                    ).type(dtype)[None]
                    for _ in range(num_axis)
                ]
            )
        )
        self.out_channels = num_parameters if out_channels is None else out_channels
        # flattened convolution values -> out_channels
        self.final_reshape = nn.Sequential(
            nn.Linear(
                self.num_convolutions * num_axis * (resolution**num_parameters),
                self.out_channels,
            ),
            nn.ReLU(),
        )
    def print_info(self):
        print("SMConvLayer, kernel matrices R:")
        print(self.Rs)
    def forward(
        self,
        sms,
    ):
        from multipers.ml.convolutions import batch_signed_measure_convolutions

        kernel_matrices = (
            # the KDE implementation expects the inverse covariance for multivariate kernels
            (R.T @ R).inverse()
            for R in self.Rs
        )
        # convolve each signed measure with each kernel, evaluated at the learned points
        convolutions = torch.stack(
            [
                batch_signed_measure_convolutions(
                    sms,
                    self.pts_to_evaluate,
                    bandwidth=k,
                    kernel="multivariate_gaussian",
                )
                for k in kernel_matrices
            ]
        )
        new_f = convolutions.swapaxes(0, 2).flatten(1)  # (num_graphs, flattened convolution values)
        new_f = self.final_reshape(new_f)
        return new_f  # one feature vector per graph
vectorization_layer = SMConvLayer(
    num_parameters=first_gcn.out_channels,
    num_axis=len(topological_layer.degrees),
    dtype=sms.dtype,
    num_convolutions=7,
    resolution=5,
    out_channels=10,
)
# test
vectorization_layer(sms).shape
torch.Size([188, 10])
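The input size of final_reshape is num_convolutions * num_axis * resolution**num_parameters, i.e. 7 * 2 * 5**3 = 1750 values per graph for the layer above. The kernels themselves are parametrized by the matrices R: the inverse covariance (R^T R)^{-1} is symmetric positive definite whenever R is invertible, which holds almost surely at initialization. A standalone check (plain torch, made-up values):
R = torch.randn(3, 3, dtype=torch.float64)
inv_cov = (R.T @ R).inverse()
assert torch.allclose(inv_cov, inv_cov.T, atol=1e-8)  # symmetric
assert (torch.linalg.eigvalsh(inv_cov) > 0).all()     # positive definite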
A graph filtration learning model
from torch.nn.functional import one_hot
from tqdm import tqdm
class GraphModel(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_parameters: int = 2,
        hidden_channels: int = 50,
        num_layers: int = 2,
        degrees: tuple[int, ...] = (0, 1),
        num_convolutions: int = 5,
        resolution: int = 5,
    ):
        super().__init__()
        # ideally, every submodule hyperparameter would be exposed as an argument here
        self.first_gcn = GCN(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            out_channels=num_parameters,
        )
        self.topological_layer = Graph2SMLayer(normalize=True, degrees=degrees)
        self.vectorization_layer = SMConvLayer(
            num_parameters=num_parameters,
            num_axis=len(degrees),
            num_convolutions=num_convolutions,
            resolution=resolution,
            out_channels=num_convolutions,
            dtype=torch.float32,
        )
        # NB: CrossEntropyLoss below already applies a log-softmax; keeping an explicit
        # Softmax here means the loss is computed on probabilities rather than logits.
        self.classifier = nn.Sequential(
            nn.Linear(num_convolutions, out_channels),
            nn.ReLU(),
            nn.Softmax(dim=-1),
        )
def forward(self,data):
out1 = self.first_gcn.forward(data.x, data.edge_index, batch = data.batch)
sms = self.topological_layer.forward(out1, data.edge_index, data.batch)
out = self.vectorization_layer(sms)
out = self.classifier(out)
return out
_stuff = train
graphclassifier = GraphModel(
in_channels = _stuff.x.shape[1],
out_channels = np.unique(_stuff.y).shape[0],
hidden_channels=10,
num_layers=2,
num_parameters=2,
num_convolutions=2,
resolution=10,
)
graphclassifier
GraphModel(
(first_gcn): GCN(13, 2, num_layers=2)
(topological_layer): Graph2SMLayer()
(vectorization_layer): SMConvLayer(
(final_reshape): Sequential(
(0): Linear(in_features=400, out_features=2, bias=True)
(1): ReLU()
)
)
(classifier): Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): ReLU()
(2): Softmax(dim=-1)
)
)
# test
graphclassifier(train)[:5]
tensor([[0.4823, 0.5177],
[0.4811, 0.5189],
[0.4811, 0.5189],
[0.4806, 0.5194],
[0.4812, 0.5188]], grad_fn=<SliceBackward0>)
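For scale, the whole model stays small; counting the trainable parameters:
sum(p.numel() for p in graphclassifier.parameters() if p.requires_grad)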
Learning
num_epoch = 100
batch_size = len(train_dataset)
data_loader = DataLoader(train_dataset,batch_size=batch_size)
graphclassifier.train()
loss = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(graphclassifier.parameters(), lr=1e-2)
losses = []
with tqdm(range(num_epoch)) as epoch:
    for i in epoch:
        for stuff in data_loader:
            optim.zero_grad()
            batch_labels = one_hot(stuff.y).type(torch.float32)
            prediction = graphclassifier(stuff)
            current_loss = loss(prediction, batch_labels)
            with torch.no_grad():
                real_classification = prediction.argmax(1)
                # flag the degenerate case where every graph gets the same label
                cst = real_classification[0] if np.unique(real_classification).shape[0] == 1 else None
                accuracy = (real_classification == stuff.y).type(torch.float32).mean(0)
            losses.append(current_loss.detach().numpy())
            epoch.set_description(
                f"Current acc {accuracy:.3f}, loss {current_loss.detach().numpy()}"
                + ("" if cst is None else f", constant to {cst}")
            )
            current_loss.backward()
            optim.step()
Current acc 0.775, loss 0.518334150314331: 100%|█| 100/100 [01:12<00:00, 1.38
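The losses list collected above is otherwise unused; a quick plot (assuming matplotlib is installed) shows the training curve:
import matplotlib.pyplot as plt

plt.plot(losses)
plt.xlabel("epoch")
plt.ylabel("training cross-entropy")
plt.show()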
graphclassifier.eval()
test_stuff = next(iter(DataLoader(test_dataset,batch_size=len(test_dataset))))
prediction = graphclassifier(test_stuff)
verdad = one_hot(test_stuff.y).type(torch.float32)
loss(prediction, verdad), (prediction.argmax(1) == test_stuff.y).type(torch.float).mean()
(tensor(0.5426, grad_fn=<DivBackward1>), tensor(0.6842))
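Accuracy alone can be misleading on MUTAG, whose two classes are roughly 2:1; a confusion matrix (scikit-learn) gives a finer picture:
from sklearn.metrics import confusion_matrix

confusion_matrix(test_stuff.y.numpy(), prediction.argmax(1).detach().numpy())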