
ICL over Graphs : PRODIGY

In this note, I will cover the paper "PRODIGY : Enabling In-context Learning Over Graphs".

NOTE : Definition of In-context Learning (ICL) is already covered.

1. Abstract

  • ICL is the ability of a pretrained model to adapt to novel and diverse downstream tasks by conditioning on prompt examples, without optimizing any parameters.
  • PRODIGY (Pretraining Over Diverse In-Context Graph Systems) is the first pretraining framework that enables ICL over graphs.

2. Introduction

  • ICL : The capability of a pretrained model to perform diverse tasks directly at prediction time when prompted with just a few examples, without any model training or fine-tuning.
  • Challenges :
    1. How to formulate node-, edge- and graph-level tasks over graphs with a unified task representation, so that the model can perform diverse tasks without any fine-tuning or retraining.
    2. How to design a model architecture and pre-training objectives that enable the model to achieve ICL over graphs.
  • Existing Work :
    • Pre-training only learns a good graph encoder; fine-tuning is then required for each downstream task.
    • Meta-learning generalizes across tasks, but only within the same graph.
  • ICL over graphs means generalizing across both graphs and tasks without any fine-tuning or retraining.

(Figure : overview of the PRODIGY framework)

  • PRODIGY :
    1. It proposes the prompt graph representation, which unifies node-, edge- and graph-level tasks over graphs.
    2. For pre-training, it uses in-context pretraining objectives : Neighbor Matching (self-supervised) and Multi-Task (supervised).
    3. It also defines a GNN architecture with an attention mechanism to pass messages over the prompt graph.

3. Few-shot Prompting

  • We define a graph as \(\mathcal{G} = (\mathcal{V},\mathcal{E}, \mathcal{R})\), where \(\mathcal{V},\mathcal{E}, \mathcal{R}\) represent the sets of nodes, edges and relations, respectively. An edge \(e = (u,r,v) \in \mathcal{E}\) consists of a subject \(u \in \mathcal{V}\), a relation \(r \in \mathcal{R}\) and an object \(v \in \mathcal{V}\).
  • Suppose we have an m-way classification task with \(|\mathcal{Y}| = m\) classes, and we define a k-shot prompt.
  • prompt examples : \(\mathcal{S} = \{(x_i, y_i)\}_{i=1}^{m \cdot k}\) with k examples per class \(y \in \mathcal{Y}\).
  • query set : \(\mathcal{Q} = \{x_i\}_{i=1}^n\), the inputs for which we want to predict labels.
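
To make this concrete, below is a minimal Python sketch of how an m-way, k-shot prompt and a query set could be sampled; `labels_by_class` and `sample_prompt` are illustrative names, not from the paper.

```python
import math
import random

def sample_prompt(labels_by_class, m, k, n):
    """Sample an m-way, k-shot prompt S = {(x_i, y_i)} and a query set Q.

    labels_by_class: dict mapping a class id to the list of example ids
    carrying that label (hypothetical input format, for illustration only).
    """
    classes = random.sample(list(labels_by_class), m)   # choose the m "ways"
    per_class_queries = math.ceil(n / m)
    prompt, queries = [], []
    for c in classes:
        ids = random.sample(labels_by_class[c], k + per_class_queries)
        prompt += [(x, c) for x in ids[:k]]              # k prompt examples per class
        queries += ids[k:]                               # held-out query candidates
    random.shuffle(queries)
    return prompt, queries[:n]
```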

4. Prompt Graph

4.1. Data Graph

  • To generate a data graph \(\mathcal{G}^D\), sample the k-hop neighbourhood of the input node set from the source graph \(\mathcal{G}\).
  • For node classification, the input node set is a singleton containing the target node. For link prediction, it is the pair of candidate nodes.
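
A minimal sketch of this neighbourhood extraction, assuming the source graph is given as a plain adjacency dict; a real implementation would also keep the induced edges and typically subsamples neighbours to bound the data graph size.

```python
from collections import deque

def khop_data_graph(adj, input_nodes, k):
    """Return the node set of the data graph G^D: all nodes within k hops
    of the input node set (BFS over an adjacency dict node -> neighbours)."""
    visited = set(input_nodes)
    frontier = deque((v, 0) for v in input_nodes)
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue
        for nbr in adj.get(node, ()):
            if nbr not in visited:
                visited.add(nbr)
                frontier.append((nbr, depth + 1))
    return visited
```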

4.2. Task Graph

  • Task graph \(\mathcal{G}^T\) consists of data nodes (\(v_{x_i}\)) and label nodes (\(v_{y_i}\)).
  • The data graph \(\mathcal{G}_i^D\) is aggregated into a single node \(v_{x_i}\) and the label \(y_i\) is represented by a label node \(v_{y_i}\).
  • So, a task graph contains \(m \cdot k + n\) data nodes and \(m\) label nodes.
  • For the query set, we add a single-direction edge from each label node to each query data node.
  • NOTE : This is done to avoid information leakage from the query set to the prompt examples.
  • For the prompt examples, we add bi-directional edges between label nodes and data nodes. Edges connecting a data node to its true label are marked \(T\), the others are marked \(F\) (see the sketch below).
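
The sketch below assembles the task-graph edge list under these rules; the node naming and the edge-tuple format are assumptions made for illustration, not the paper's actual data structures.

```python
def build_task_graph(prompt, queries, classes):
    """Build the edge list of the task graph G^T.

    prompt:  list of (data_node_id, class_id) prompt examples.
    queries: list of data_node_ids in the query set.
    classes: the m class ids, one label node per class.
    Each edge is (src, dst, edge_type); 'T'/'F' marks true/false label edges,
    while query edges get 'unknown' since their label is what we predict.
    """
    edges = []
    for x, y in prompt:                              # prompt examples: bi-directional
        for c in classes:
            t = "T" if c == y else "F"
            edges.append((("data", x), ("label", c), t))
            edges.append((("label", c), ("data", x), t))
    for q in queries:                                # queries: label -> query only
        for c in classes:
            edges.append((("label", c), ("data", q), "unknown"))
    return edges
```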

5. Pre-training : Message Passing Architecture

5.1. Data Graph Message Passing

  • A GraphSAGE-style GNN is applied to the data graph \(\mathcal{G}^D\) to learn an embedding for each node, giving \(E \in \mathbb{R}^{|\mathcal{V}^D| \times d}\).
  • For node classification : \(E_{v_{x_i}} = E_{v_i}\) (take the node embedding of the target node \(v_i\)).
  • For link prediction : \(E_{v_{x_i}} = W^T \big( E_{v_{i1}} \,\|\, E_{v_{i2}} \,\|\, \max_{j \in \mathcal{V}^D} E_j \big) + b\)
    • Concatenate the node embeddings of the two endpoint nodes with a max-pooling of all node embeddings in the data graph,
    • then project it back to the \(d\)-dimensional embedding space.
    • \(W \in \mathbb{R}^{3d \times d}\) is a learnable weight matrix and \(b \in \mathbb{R}^d\) is a learnable bias vector.
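
A PyTorch sketch of this link-prediction read-out; tensor names and the calling convention are assumptions made for illustration.

```python
import torch

def link_readout(E, i1, i2, W, b):
    """Pool a data graph into one embedding E_{v_{x_i}} for link prediction.

    E: (|V^D|, d) node embeddings from the data-graph GNN.
    i1, i2: indices of the two endpoint nodes of the candidate link.
    W: (3d, d) learnable projection, b: (d,) learnable bias.
    """
    pooled = E.max(dim=0).values                     # max-pool over all nodes in G^D
    z = torch.cat([E[i1], E[i2], pooled], dim=-1)    # concatenate to a 3d-dim vector
    return z @ W + b                                 # project back to d dimensions
```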

5.2. Task Graph Message Passing

  • The embedding of a label node \(v_{y_i}\) can be initialized either with a random Gaussian vector or with any additional information available about the labels.
  • Each edge also has two binary features \(e_{ij}\) that indicate
    1. whether the edge connects a prompt example or a query, and
    2. whether the edge type is \(T\) or \(F\).
  • The GNN layer uses an attention mechanism with key, query and value (\(K\), \(Q\), \(V\)) projections, similar to a transformer.
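
A minimal sketch of one such attention step, assuming the task-graph connectivity is given as a boolean mask (with self-loops so every node has at least one incoming edge) and ignoring the binary edge features and multi-head details.

```python
import torch
import torch.nn.functional as F

def task_graph_attention(H, adj_mask, Wq, Wk, Wv):
    """One transformer-style attention step over the task graph.

    H: (N, d) embeddings of the data and label nodes.
    adj_mask: (N, N) boolean, True where node j may send a message to node i
              (include self-loops to avoid fully masked rows).
    Wq, Wk, Wv: (d, d) projection matrices.
    """
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    scores = (Q @ K.T) / K.shape[-1] ** 0.5                 # scaled dot-product scores
    scores = scores.masked_fill(~adj_mask, float("-inf"))   # attend only along edges
    attn = F.softmax(scores, dim=-1)
    return attn @ V                                         # updated node embeddings
```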

5.3. Prediction Read-out

  • Obtain the classification logits \(O_i\) by taking the cosine similarity between the query data node embedding and each label node embedding.
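
In code, this read-out is essentially a one-liner (PyTorch sketch; names are illustrative):

```python
import torch
import torch.nn.functional as F

def readout_logits(query_emb, label_embs):
    """Logits O_i: cosine similarity between one query data node embedding
    of shape (d,) and the m label node embeddings of shape (m, d)."""
    return F.cosine_similarity(query_emb.unsqueeze(0), label_embs, dim=-1)  # (m,)
```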

6. Pre-training : In-context pretraining objectives

6.1. Neighbor Matching

  • This is a self-supervised task.
  • We sample multiple subgraphs from \(\mathcal{G}_{\text{pretrain}}\) as local neighbourhoods, and we say a node belongs to a neighbourhood if it is in the sampled subgraph.
  • \(\mathtt{NM}_{k, m}\) is a sampler which generates an \(m\)-way neighbor matching task with \(k\)-shot prompt examples and labels for the queries.
  • First, sample \(m\) nodes from \(\mathcal{G}_{\text{pretrain}}\); each node corresponds to one class.
  • Then, sample \(k\) nodes from the \(l\)-hop neighbourhood of each class node \(c_i\); these serve as the prompt examples.
  • Lastly, sample \(\lceil \frac{n}{m} \rceil\) nodes from the neighbourhood of each class node \(c_i\) to form the query set (see the sampler sketch below).
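
A sketch of the \(\mathtt{NM}_{k, m}\) sampler, reusing the hypothetical `khop_data_graph` helper sketched in Section 4.1; it assumes each sampled neighbourhood has enough nodes to draw from.

```python
import math
import random

def nm_sampler(adj, m, k, n, l):
    """Sample a self-supervised m-way neighbor-matching task from G_pretrain.

    Each sampled "class" node c_i defines a pseudo-class: a node belongs to
    class i if it lies inside the sampled l-hop neighbourhood of c_i.
    """
    class_nodes = random.sample(list(adj), m)
    prompt, queries = [], []
    for i, c in enumerate(class_nodes):
        neigh = list(khop_data_graph(adj, [c], l) - {c})    # Section 4.1 sketch
        picks = random.sample(neigh, k + math.ceil(n / m))
        prompt += [(x, i) for x in picks[:k]]               # k shots for class i
        queries += [(x, i) for x in picks[k:]]              # queries with true class i
    return prompt, queries
```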

6.2. Multi-task

  • This is a supervised task.
  • If node- or edge-level labels are present for \(\mathcal{G}_{\text{pretrain}}\), we can leverage them to construct a pretraining task similar to neighbor matching.
  • Sample \(m\) labels from the whole label set, then sample \(k\) support examples for each label class.
  • Lastly, sample \(\lceil \frac{n}{m} \rceil\) nodes from each label class to form the query set.
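
Since this mirrors the prompt construction of Section 3, the hypothetical `sample_prompt` helper sketched there can serve as the multi-task sampler directly, once `labels_by_class` is built from the observed node- or edge-level labels:

```python
# Illustrative usage of the Section 3 sketch as the MT sampler.
mt_prompt, mt_queries = sample_prompt(labels_by_class, m=3, k=5, n=6)
```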

6.3. Prompt Graph Augmentation and Pretraining Loss

  • Prompt Graph Augmentation :
    1. Node Dropping : From the data graph of prompt examples and queries, randomly drop a few nodes from its \(l\)-hop neighborhood.
    2. Node Feature Masking : From the data graph of prompt examples and queries, randomly mask the features of a few nodes with a zero vector.
  • Pretraining Loss :
    • \(\mathcal{L} = \underset{x_i \in \mathcal{Q}_{\mathtt{task}}}{\mathbb{E}}\; \mathtt{CE}(O_{\mathtt{task},i}, y_{\mathtt{task},i})\), where \(\mathtt{task} \in \{\mathtt{NM}, \mathtt{MT}\}\).
    • This is the cross-entropy loss between the predicted logits and the true class labels, averaged over the queries of both pretraining tasks.
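
A small PyTorch sketch of the two augmentations and the loss; function names and the masking/dropping probability are illustrative assumptions.

```python
import random
import torch
import torch.nn.functional as F

def node_drop(nodes, keep_targets, p=0.1):
    """Node dropping: randomly remove a fraction p of the l-hop neighbourhood
    nodes of a data graph, always keeping the target/input nodes."""
    return [v for v in nodes if v in keep_targets or random.random() >= p]

def node_feature_mask(X, p=0.1):
    """Node feature masking: zero out the feature vectors of a random subset
    of nodes. X: (num_nodes, d) feature matrix."""
    X = X.clone()
    X[torch.rand(X.shape[0]) < p] = 0.0
    return X

def pretraining_loss(logits, targets):
    """Cross-entropy between the read-out logits O_task (num_queries, m) and
    the true class indices y_task (num_queries,), averaged over queries."""
    return F.cross_entropy(logits, targets)
```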

7. Experiments