Filtration Learning

import torch
import torch_geometric.transforms as T
import multipers as mp
from torch_geometric.datasets import TUDataset
from os.path import expanduser
from torch_geometric.data import Data
import torch.nn as nn
import numpy as np
import multipers.ml.signed_measures as mms
import multipers.grids as mpg
torch.manual_seed(1)
## TODO : fixme
import warnings
warnings.filterwarnings("ignore")

This code is not meant to realize state of the art graph classification, but to give an idea on how to use multipers in a DL setting.

Dataset

dataset_name = "MUTAG"
def get_max_degree(dataset_name):
    from torch_geometric.utils import degree
    dataset = TUDataset(expanduser("~/Datasets/torch_geometric/"),dataset_name, use_node_attr=True,cleaned=True)
    num_nodes = dataset.edge_index.max()+1 # only this matters, we're computing max_degree
    assert not Data(edge_index=dataset.edge_index, num_nodes = num_nodes).is_directed()
    a= degree(index = dataset.edge_index[0])
    b = degree(index = dataset.edge_index[1])
    assert (a==b).all() # because is_directed I guess
    max_degree = a.max()
    return int(max_degree)
transform = T.Compose([
    # T.GDC(diffusion_kwargs={
    #     "method":"heat",
    #     "t":10,
    # }),
    T.Constant(1), # Constant_value
    T.LocalDegreeProfile(),
    # T.OneHotDegree(max_degree=get_max_degree(dataset_name)), # degree before removing edges
    T.RemoveDuplicatedEdges(),
])
dataset = TUDataset(expanduser("~/Datasets/torch_geometric/"),dataset_name, use_node_attr=True,cleaned=False, transform=transform)
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

batch_size=len(dataset)
# batch_size=100
shuffled_dataset = dataset.shuffle()
dataset_size = len(dataset)
split = int(0.9*dataset_size)
train_dataset, test_dataset = dataset[:split], shuffled_dataset[split:]
train_dataset.x.shape, test_dataset.x.shape, dataset.x.shape
(torch.Size([3027, 7]), torch.Size([336, 7]), torch.Size([3371, 7]))
train = next(iter(DataLoader(dataset, batch_size=len(dataset))))

Some GCN

from torch_geometric.nn.models import GCN
out_channels1 = 3 ## Note: this is the number of parameter on which to compute Multiparameter Persistence; keep it low!
first_gcn = GCN(in_channels=train.x.shape[-1], hidden_channels=50, num_layers=5, out_channels=out_channels1)
## test
out1 = first_gcn.forward(train.x, train.edge_index, batch = train.batch)
out1.shape, out1.dtype
(torch.Size([3371, 3]), torch.float32)

Some topological layers

Torch Graphs to Signed Measures

from torch_geometric.utils import unbatch, unbatch_edge_index
class Graph2SMLayer(torch.nn.Module):
    def __init__(
        self,
        degrees=[0, 1],
        grid_strategy:str = "exact",
        resolution: int = -1,  # meant to crash if grid needs resolution
        n_jobs=-1,  # parallelize signed measure computations
        normalize: bool = False,
    ):
        super().__init__()
        self.normalize = normalize
        self.degrees = degrees
        self.grid_strategy = grid_strategy
        self.resolution = resolution
        self.n_jobs = n_jobs

    @torch.no_grad
    def _simplextree_transform(
        self,
        nodes_indices,
        nodes_filtrations,
        edge_indices,
        diff_grid,
    ):
        """
        Given a graph g:Data, and filtrations = [node_filtrations, (opts) edge filtrations],
        create the associated simplextree.
        """
        num_parameters = nodes_filtrations.size(1)
        numpy_node_filtrations = nodes_filtrations.detach().numpy()
        st = mp.SimplexTreeMulti(num_parameters=num_parameters)
        nodes = nodes_indices[None, :].detach().numpy()
        st.insert_batch(nodes, numpy_node_filtrations)
        edges = edge_indices.detach().numpy()
        numpy_edges_filtrations = np.empty((0,0), dtype = st.dtype)
        st.insert_batch(
            edges,numpy_edges_filtrations
        )  # empty -> -inf
        st = st.grid_squeeze(diff_grid, coordinate_values=True)
        
        if num_parameters == 2:
            st.collapse_edges(-1)
        sms = mp.signed_measure(st, degrees=self.degrees, coordinate_measure=True)
        return sms

    def _get_diff_grids(self, node_filtration_iterable):
        from multipers.torch.diff_grids import get_grid
        todo = get_grid(self.grid_strategy)
        return tuple(todo(x.T, self.resolution) for x in node_filtration_iterable)

    @torch.no_grad
    def data2coordinate_sms(
        self, node_indices, nodes_filtrations, edges_indices, diff_grids
    ):
        from joblib import Parallel, delayed
        sms = Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(self._simplextree_transform)(
                node_index,
                nodes_filtration,
                edge_index,
                diff_grid,
            )
            for node_index, nodes_filtration, edge_index, diff_grid in zip(
                node_indices, nodes_filtrations, edges_indices, diff_grids
            )
        )
        return sms


    def forward(
        self, nodes_filtrations, edges_indices, batch_indices, *, simplex_tree_list=None
    ):
        if batch_indices is None:
            nodes_filtrations = [nodes_filtrations]
        else:
            from torch_geometric.utils import unbatch, unbatch_edge_index
            node_indices = unbatch(torch.arange(nodes_filtrations.shape[0]), batch = batch_indices)
            nodes_filtrations = unbatch(nodes_filtrations, batch=batch_indices)
            edges_indices = unbatch_edge_index(edges_indices, batch=batch_indices)
            
        grids = self._get_diff_grids(nodes_filtrations)

        with torch.no_grad():
            sms = self.data2coordinate_sms(
                node_indices,
                nodes_filtrations,
                edges_indices,
                diff_grids=grids,
            )
        # Joblib doesn't seem to be possible with pytorch
        sms = tuple(
            mpg.sms_in_grid(sm, diff_grid) for sm, diff_grid in zip(sms, grids)
        )
        sms = mms.SignedMeasureFormatter(
            unrag=True, deep_format=True, normalize=self.normalize
        ).fit_transform(sms)

        return sms
topological_layer = Graph2SMLayer(normalize = True, degrees=[0,1], n_jobs=1)
#test
sms = topological_layer.forward(out1, train.edge_index, train.batch)
sms.dtype
[KeOps] Warning : Cuda libraries were not detected on the system or could not be loaded ; using cpu only mode
torch.float32

Vectorization Layer

class SMConvLayer(torch.nn.Module):
    def __init__(
        self,
        num_parameters: int,
        num_axis: int,
        dtype=torch.float64,
        num_convolutions: int|None = None,
        resolution:int = 5,
        out_channels:int|None=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.num_parameters = num_parameters
        self.resolution = resolution
        self.num_convolutions = (
            num_parameters if num_convolutions is None else num_convolutions
        )

        biases = torch.stack(
            [
                10*torch.diag(torch.rand(self.num_parameters, dtype=dtype))
                for _ in range(self.num_convolutions)
            ],
            dim=0,
        ).type(dtype)
        self.Rs = nn.Parameter(
            torch.randn(
            # torch.rand(
            # torch.zeros(
                (self.num_convolutions, num_parameters, num_parameters),
                dtype=dtype,
                requires_grad=True,
            )
            + biases  # Maybe add a multiplicative factor ?
        ).type(dtype)
        self.pts_to_evaluate = nn.Parameter(torch.stack([
            torch.cartesian_prod(*(torch.linspace(0,1,resolution) for _ in range(num_parameters))).type(dtype)[None] 
            for _ in range(num_axis)
        ])).type(dtype) # initially pts on a grid
        self.out_channels = num_parameters if out_channels is None else out_channels
        self.final_reshape = nn.Sequential(nn.Linear(num_convolutions*num_axis*(resolution**num_parameters),out_channels), nn.ReLU())
        
    def print_info(self):
        print("SMConvLayer, bandwidths")
        print(self.Rs)

    def forward(
        self,
        sms,
    ):
        from multipers.ml.convolutions import batch_signed_measure_convolutions
        kernel_matrices = (
            # This KDE implementation expects the inverse of the covariance for multiparameter kernels
            (R.T @ R).inverse()
            for R in self.Rs
        )
        ## compute convolutions
        convolutions = torch.stack(
            [
                batch_signed_measure_convolutions(
                    sms,
                    self.pts_to_evaluate,
                    bandwidth=k,
                    kernel="multivariate_gaussian",
                )
                for k in kernel_matrices
            ]
        )
        new_f = convolutions.swapaxes(0,2).flatten(1) # num_data, merged stuff
        new_f = self.final_reshape(new_f)
        return new_f  # New node filtration values

        
        
vectorization_layer = SMConvLayer(num_parameters=first_gcn.out_channels, num_axis=len(topological_layer.degrees), dtype = sms.dtype, num_convolutions = 7, resolution = 5, out_channels=10)
# test 
vectorization_layer(sms).shape
torch.Size([188, 10])

A graph filtration learning model

from torch.nn.functional import one_hot
from tqdm import tqdm
class GraphModel(nn.Module):
    def __init__(
        self,
        in_channels:int,
        out_channels:int,
        num_parameters:int=2,
        hidden_channels:int=50,
        num_layers:int=2,
        degrees:list[int]=[0,1],
        num_convolutions:int = 5,
        resolution:int=5,
    ):
        super().__init__()
        ## in an ideal world, put the parameters in the init
        self.first_gcn = GCN(in_channels=in_channels, hidden_channels=hidden_channels, num_layers=num_layers, out_channels=num_parameters)
        self.topological_layer = Graph2SMLayer(normalize = True, degrees=[0,1])
        self.vectorization_layer = SMConvLayer(num_parameters=num_parameters, num_axis=len(degrees), num_convolutions = num_convolutions, resolution = resolution, out_channels=num_convolutions, dtype = torch.float32)
        self.classifier = nn.Sequential(
            nn.Linear(num_convolutions, out_channels), 
            nn.ReLU(), 
            nn.Softmax(dim=-1),
        )
    def forward(self,data):
        out1 = self.first_gcn.forward(data.x, data.edge_index, batch = data.batch)
        sms = self.topological_layer.forward(out1, data.edge_index, data.batch)
        out = self.vectorization_layer(sms)
        out = self.classifier(out)
        return out

_stuff = train
graphclassifier = GraphModel(
    in_channels = _stuff.x.shape[1], 
    out_channels = np.unique(_stuff.y).shape[0],
    hidden_channels=10,
    num_layers=2,
    num_parameters=2,
    num_convolutions=2,
    resolution=10,
)
graphclassifier
GraphModel(
  (first_gcn): GCN(13, 2, num_layers=2)
  (topological_layer): Graph2SMLayer()
  (vectorization_layer): SMConvLayer(
    (final_reshape): Sequential(
      (0): Linear(in_features=400, out_features=2, bias=True)
      (1): ReLU()
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=2, out_features=2, bias=True)
    (1): ReLU()
    (2): Softmax(dim=-1)
  )
)
# test
graphclassifier(train)[:5]
tensor([[0.4823, 0.5177],
        [0.4811, 0.5189],
        [0.4811, 0.5189],
        [0.4806, 0.5194],
        [0.4812, 0.5188]], grad_fn=<SliceBackward0>)

Learning

num_epoch = 100
batch_size = len(train_dataset)
data_loader = DataLoader(train_dataset,batch_size=batch_size)
graphclassifier.train()
loss = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(graphclassifier.parameters(), lr=1e-2)
losses = []
with tqdm(range(num_epoch)) as epoch:
   for i in epoch:
       for stuff in data_loader:
            optim.zero_grad()
            batch_labels = one_hot(stuff.y).type(torch.float32)
            prediction = graphclassifier(stuff)
            current_loss = loss(prediction, batch_labels)
            
            with torch.no_grad():
                real_classification = prediction.argmax(1)
                cst = real_classification[0] if np.unique(real_classification).shape[0] == 1 else None
                accuracy = (real_classification == stuff.y).type(torch.float32).mean(0)
                losses.append(current_loss.detach().numpy())
                epoch.set_description(f"Current acc {accuracy:.3f}, loss {current_loss.detach().numpy()}, {"" if cst is None else f"constant to {cst}"}")
            current_loss.backward()
            optim.step()
Current acc 0.775, loss 0.518334150314331, : 100%|█| 100/100 [01:12<00:00,  1.38
graphclassifier.eval()
test_stuff = next(iter(DataLoader(test_dataset,batch_size=len(test_dataset))))
prediction = graphclassifier(test_stuff)
verdad = one_hot(test_stuff.y).type(torch.float32)
loss(prediction, verdad), (prediction.argmax(1)==test_stuff.y).type(torch.float).abs().mean()
(tensor(0.5426, grad_fn=<DivBackward1>), tensor(0.6842))