Scalable graph machine learning: a mountain we can climb?
Graph machine learning is still a relatively new and developing area of research, so you could be forgiven for a lack of familiarity with its place in deep learning.
But we’re fast discovering that interconnected problems like money laundering, media propagation and cybercrime are particularly well suited to graph analytics. We can draw powerful insights from connected data, like predicting relationships between entities that are missing from the data, or resolving unknown entities across different networks.
This is the key value of graph analytics: the ability to capture information from the complex relationships in the data. Instead of seeing data in silos, we can access a much bigger picture.
The application of this technology is limitless. Government departments in areas like law enforcement and health will almost certainly have connected data of some sort, as will industries like telecommunications, banking, insurance, marketing and social networks. There are countless large-scale, connected datasets waiting for the opportunity to solve some of our greatest challenges.
However, when trying to tackle problems in the real world, we’re met with the challenge that the data is not static, but constantly evolving in complexity and size. The problem of scalability both fascinates and frustrates those of us working with graph algorithms.
What is graph machine learning?
When we say graph, we’re talking about a way of representing data as entities with connections between them. In mathematical terms, we call an entity a node or vertex, and a connection an edge. A collection of vertices V, together with a collection of edges E, form a graph G = (V, E).
Graph machine learning is a machine learning technique that can naturally learn and make predictions from graph-structured data. We can think of machine learning as a way of learning some transformation function; y = f(x), where x is a piece of data and y is something we want to predict about the data.
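To make this concrete, here is a toy sketch in plain Python; the graph, the feature values and the stand-in model are all made up purely for illustration.

```python
# A tiny graph G = (V, E) represented with plain Python collections.
V = {"A", "B", "C", "D"}                      # vertices (nodes)
E = {("A", "B"), ("A", "C"), ("A", "D")}      # edges (connections)

# Each node carries a feature vector x; the values here are made up.
features = {
    "A": [0.1, 0.9],
    "B": [0.4, 0.3],
    "C": [0.8, 0.2],
    "D": [0.5, 0.5],
}

# Machine learning boils down to learning some function y = f(x).
# This stand-in just sums the features; a real model would be
# learned from labelled training data.
def f(x):
    return sum(x)

y_A = f(features["A"])   # a (toy) prediction about node A
```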
What do we mean by scalability?
A scalable mountain is a mountain people can climb. A scalable system is one that can handle growing demands. A scalable graph machine learning method should be a method that handles growing data sizes… and it also happens to be a huge mountain to climb.
In a more traditional deep learning pipeline, if we want to predict something about x, we only need information about x itself. But with the graph structure in mind, in order to predict something about x we potentially need to aggregate information from the entire dataset.
As a dataset gets larger and larger, suddenly we end up aggregating terabytes of data just to make a prediction about a single data point. Scalable? Not so much. Gigabytes of new data may arrive every day, and this is what makes scalability such an important consideration.
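To see why, here is a deliberately naive sketch (not any particular library’s API) of an aggregation that recursively pulls in features from every neighbour of every neighbour; on a large, well-connected graph this quickly touches most of the dataset.

```python
# Adjacency and features for a toy graph; all values are illustrative only.
neighbours = {"A": ["B", "C"], "B": ["A", "C"], "C": ["A", "B", "D"], "D": ["C"]}
features   = {"A": [1.0, 0.0], "B": [0.0, 1.0], "C": [0.5, 0.5], "D": [0.2, 0.8]}

def aggregate(node, depth):
    """Naively aggregate features over the FULL neighbourhood,
    recursing `depth` hops out -- no sampling, no bound on fan-out."""
    if depth == 0:
        return features[node]
    gathered = [aggregate(n, depth - 1) for n in neighbours[node]]
    stacked = [features[node]] + gathered
    # element-wise mean of the node's own features and everything gathered
    return [sum(vals) / len(vals) for vals in zip(*stacked)]

# Predicting something about "A" with depth 2 already touches every node here;
# on a graph with millions of nodes, this is the scalability problem.
representation_A = aggregate("A", depth=2)
```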
Delving deeper
Many graph machine learning methods are inherently transductive because of the way they aggregate information from the entire dataset. Inductive algorithms try to discover a general rule for the world, using the training data as a basis for making predictions about unseen data. Transductive algorithms, by contrast, do not generalise to a universal model; they instead aim to make better predictions for the specific unlabelled data already in the dataset.
Let’s break this down. Consider a node A in a graph that is connected to three other nodes B, C, and D.
If we weren’t applying any fancy graph methods, we would simply be learning a function that maps from the features of A to a more useful metric; e.g., a prediction we want to make about the node. A graph-based method, on the other hand, also aggregates information from A’s neighbours B, C, and D when making that prediction.
Now consider that after we’ve trained the model, a new data point arrives sometime in the future that happens to be connected to our original node A. The function we learned doesn’t take this new connection into account, so we are now uncertain whether the model we trained is valid for our new set of data.
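Here is a rough numerical sketch of that fragility, using a made-up helper (mean_of_neighbourhood) and made-up feature values: a model whose input is an aggregation over A’s neighbourhood sees its input change the moment a new neighbour appears.

```python
# Toy features for node A and its neighbours; all values are made up.
features = {"A": [0.1, 0.9], "B": [0.4, 0.3], "C": [0.8, 0.2], "D": [0.5, 0.5]}
neighbours = {"A": ["B", "C", "D"]}

def mean_of_neighbourhood(node):
    """Graph-aware input: A's own features averaged with its neighbours'."""
    vectors = [features[node]] + [features[n] for n in neighbours[node]]
    return [sum(vals) / len(vals) for vals in zip(*vectors)]

before = mean_of_neighbourhood("A")

# A new node E arrives and connects to A: the aggregated input for A changes,
# so a model fitted to the old neighbourhood may no longer be valid.
features["E"] = [0.9, 0.1]
neighbours["A"].append("E")
after = mean_of_neighbourhood("A")

print(before != after)   # True -- the very input to the model has shifted
```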
So far, our understanding of graph algorithms suggests they are generally not very scalable, particularly if the algorithm is transductive in nature.
Introducing GraphSAGE
A typical way many algorithms try to tackle the scalability problem in graph machine learning is to incorporate sampling. The GraphSAGE [1] algorithm does just that, taking a fixed-size sample of any given node’s local neighbourhood. This solves the problem of needing to aggregate information from across the entire dataset.
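As a minimal sketch of the idea (not the actual GraphSAGE implementation), we can sample a fixed number of neighbours with replacement, so the input to the aggregation step always has the same shape.

```python
import random

def sample_neighbourhood(node, neighbours, sample_size):
    """Return a fixed-size sample of `node`'s neighbours.
    Sampling with replacement keeps the output size constant even
    when a node has fewer than `sample_size` neighbours."""
    candidates = neighbours[node]
    return [random.choice(candidates) for _ in range(sample_size)]

# Illustrative only: a hub with 50 neighbours and a node with just two.
neighbours = {"hub": [f"n{i}" for i in range(50)], "leaf": ["a", "b"]}

print(sample_neighbourhood("hub", neighbours, sample_size=10))
print(sample_neighbourhood("leaf", neighbours, sample_size=10))
# Both calls return exactly 10 neighbours, regardless of the true degree.
```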
But what are we sacrificing by doing so?
1. First and most obviously, taking a sample means we’re taking an approximation of what the neighbourhood actually looks like. Depending on the size of the sample, it may be a good enough approximation for our purposes, but an approximation nonetheless.
2. We give up the chance for our model to learn something from how connected a node is. For GraphSAGE, a node with five neighbours looks exactly the same as a node with 50 neighbours since we always sample the same number of neighbours for each node.
3. Finally, we end up in a world where we could make different predictions about a node based on which neighbours we happened to sample at the time.
Depending on the problem we’d like to solve and what we know about our data, we can take a guess at how these issues may affect our results before deciding whether GraphSAGE is a suitable algorithm for our use-case.
I learned first-hand, when applying graph machine learning techniques to identify fraudulent behaviour in a bitcoin blockchain dataset – which has millions of wallets (nodes) and billions of transactions (edges), making most graph machine learning methods infeasible – that the GraphSAGE algorithm provided a good approach for scaling graph machine learning (read StellarGraph’s full article on Medium).
Success, but not without challenges
GraphSAGE uses the neighbourhood sampling approach to overcome some of the challenges of scalability. Specifically, it:
● gives us a good approximation whilst bounding the input size for making predictions; and
● allows for an inductive algorithm.
This is a solid breakthrough, but doesn’t leave us without problems to be solved.
1. Efficient sampling is still difficult
In order to sample the neighbours of a node without introducing bias, you still need to iterate through all of them. This means that although GraphSAGE does restrict the size of the input, the step required to populate that input still involves looking through the entire graph, which can be very costly.
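To make that cost concrete, here is a standard reservoir-sampling sketch (Algorithm R, nothing GraphSAGE-specific): it draws an unbiased fixed-size sample from a stream of neighbours, but it still has to touch every neighbour once.

```python
import random

def reservoir_sample(stream, k):
    """Uniformly sample k items from a stream of unknown length (Algorithm R).
    Note the loop still visits EVERY item once -- exactly the cost described
    above when the 'stream' is the full neighbour list of a high-degree node."""
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            j = random.randint(0, i)   # inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir

# Even for a node with a million neighbours, an unbiased sample of 10
# requires a full pass over all one million of them.
sample = reservoir_sample((f"n{i}" for i in range(1_000_000)), k=10)
```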
2. Even with sampling, neighbourhood aggregation still aggregates A LOT of data
Even with a fixed neighbourhood size, applying this scheme recursively means that you get an exponential explosion of the neighbourhood size. For example, if we take 10 random neighbours each time but apply the aggregation over three recursive steps, this ultimately results in a neighbourhood size of 10³.
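The arithmetic is easy to sketch; the numbers below simply mirror the example above (10 sampled neighbours, three recursive steps).

```python
neighbours_per_hop = 10
hops = 3

nodes_touched = 0
frontier = 1                     # start from a single node
for _ in range(hops):
    frontier *= neighbours_per_hop
    nodes_touched += frontier

print(frontier)        # 1000 nodes in the outermost hop (10**3)
print(nodes_touched)   # 1110 sampled nodes in total across all hops
```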
3. Distributed data introduces even more challenges for graph-based methods
Much of the big data ecosystem revolves around distributing data to enable parallelised workloads as well as the ability to scale out horizontally. However, naively distributing graph data introduces a significant problem as there is no guarantee that neighbourhood aggregation can be done without communication across the network. For graph-based methods, you’ll either be stuck paying the cost of shuffling data across the network, or miss out on the value of using big data technologies to enable your pipeline.
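As a rough illustration (no particular big data framework implied; the graph, worker count and partitioning scheme are all made up), naively hashing the nodes of even a simple graph onto a handful of workers leaves most edges crossing partition boundaries, and each crossing edge is a neighbour lookup that requires network communication.

```python
import random

random.seed(0)

# A toy ring graph of 1,000 nodes, naively hashed onto 4 workers.
num_nodes, num_workers = 1_000, 4
edges = [(i, (i + 1) % num_nodes) for i in range(num_nodes)]
partition = {node: random.randrange(num_workers) for node in range(num_nodes)}

cross_partition = sum(1 for u, v in edges if partition[u] != partition[v])
print(f"{cross_partition}/{len(edges)} edges require cross-worker communication")
# With 4 workers, roughly 3 out of every 4 edges land on different machines.
```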
There are still mountains to climb and more exploration to be done to make scalable graph machine learning more practical. To learn more about graph machine learning, visit stellargraph.io or explore StellarGraph’s open source Machine Learning Python Library.
Article adapted from Kevin Jung’s ‘Scalable machine learning: a mountain we can climb?‘, published on Medium on 18 November 2019. This work is supported by CSIRO’s Data61, Australia’s leading digital research network.
References
1. Inductive Representation Learning on Large Graphs. W.L. Hamilton, R. Ying, and J. Leskovec. Neural Information Processing Systems (NIPS), 2017