The moment a three-point shot is launched, there is a long, tense pause: will it drop effortlessly through the basket or bounce chaotically off the rim? Everyone instantly has a feeling about whether the shot will be good. Sometimes we are right.

We set out to teach a computer to predict whether a shot will be good or a miss. The computer learns the same way we do, by watching lots of shots. (We don’t teach the computer about acceleration, gravity, spin, or any of the other factors that actually affect the motion of the ball.)

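This post describes a deep learning model for predicting basketball trajectories. At the outset, we started with two key questions:
- Can we predict whether a three-point shot will be made from its trajectory?
- Can we sample new trajectories from the model?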

This post explains how we answered both of these questions. It describes the dataset, key parts of our model, the results, and how to generate new trajectories. For more detail on this work, please read our paper for the Large Scale Sports Analytics workshop at KDD 2016 and run the code from GitHub.

Data

To study basketball trajectories, we needed a dataset of them. From 631 NBA games at the beginning of the 2015-2016 season, we identified over 20,000 three-point shot attempts. More background on the NBA SportVU data is in our earlier work. The interactive plot shows made baskets in blue and missed baskets in red. Try moving the distance-to-basket slider from 0 to 13 feet. Can you gauge which shots are likely to be made? Our goal was to create a model that predicts whether a shot will be made using only data recorded while the ball is more than 13 feet from the basket! To explore further, use the number-of-shots slider to change how many shots are shown.

Recurrent Neural Networks

Recurrent neural networks (RNNs) are a class of deep learning models used to predict and generate sequences. They have had enormous success in text analysis, voice recognition, and Kaggle competitions. This approach can also generate amazingly realistic sequences, such as text and music. Our work was principally inspired by a model Graves developed to analyze handwriting. Graves used XY sequential data recorded from handwriting on a smart whiteboard and trained an RNN on the raw XY data without any preprocessing. The model could predict the next letter or word and even generate sequences from different starting points. I also used his model to generate two samples of different handwriting.

To some extent, handwriting is similar to the movement of players and the ball. Now imagine applying this generative model to predict player and ball movement. You could create fictional scenarios with ball and player movement in the style of a particular player: for example, a penetration drive in the style of Jeremy Lin. This is the eventual goal of our work.

Basketball RNN Model

The structure of our final model is shown below. The inputs are the XYZ position of the ball and the game clock at each time step. The final model is a two-layer LSTM network with 64 hidden units whose outputs feed a mixture density network (MDN); it was implemented in TensorFlow.

The next few sections walk through the MDN, model performance, and generating trajectories. Readers seeking more background on how RNNs work should read Olah, Britz, or Karpathy. A code snippet for defining our LSTM cells is shown below. This post stays at a high level, but for more code, please check out the GitHub repo.

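    import tensorflow as tf  # TF 1.x API; use tf.compat.v1 under TF 2.x
    with tf.name_scope("LSTM") as scope:
      # One new LSTM cell per layer: [cell] * num_layers would reuse a single cell object
      cells = [tf.nn.rnn_cell.LSTMCell(hidden_size, use_peepholes=True) for _ in range(num_layers)]
      cell = tf.nn.rnn_cell.MultiRNNCell(cells)
      # Dropout on the stacked cell's output; keep_prob is fed at train/test time
      cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=self.keep_prob)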

Mixture Density Network

With sequential data, such as handwriting or basketball plays, there are often many possibilities at every time step. For example, with handwriting, the next step could be an upstroke or a downstroke. In a basketball play, the ball may stay with a player, arc up toward the basket, or move at high speed to another player as a pass. Graves found that a mixture density network (MDN) is an effective way to account for these various outcomes. For basketball trajectories, the intuition behind using an MDN is this: suppose that at one time step the ball is at (1,1) and at the next it is at (2,2). The next coordinate is probably (3,3), but it could also be (3.1,3) or (2.9,2.8).

How would we discern between those cases?

A mixture density network parametrizes a distribution over the X, Y, and Z coordinates. In our case, the distribution is a mixture of Gaussians. This distribution assigns a probability density to every point in three-dimensional (X, Y, Z) space. Note that the distribution works over the offsets (the change in position from one time step to the next), not over the absolute coordinates. The four plots below illustrate this process; the color indicates the shape of the distribution.
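To make the offset idea concrete, here is a minimal plain-Python sketch (illustrative only, not code from our repo) that converts absolute positions into per-step offsets and rebuilds the positions from a start point. The function names and the toy coordinates are assumptions made for the example.

```python
from typing import List, Tuple

Point = Tuple[float, float, float]


def to_offsets(positions: List[Point]) -> List[Point]:
    """Convert absolute (x, y, z) positions into per-step offsets."""
    return [
        (x1 - x0, y1 - y0, z1 - z0)
        for (x0, y0, z0), (x1, y1, z1) in zip(positions, positions[1:])
    ]


def from_offsets(start: Point, offsets: List[Point]) -> List[Point]:
    """Rebuild absolute positions by cumulatively adding offsets to a start point."""
    path = [start]
    for dx, dy, dz in offsets:
        x, y, z = path[-1]
        path.append((x + dx, y + dy, z + dz))
    return path


# Toy coordinates in feet; made-up values, not SportVU data
positions = [(1.0, 1.0, 8.0), (2.0, 2.0, 9.0), (3.0, 3.0, 9.5)]
offsets = to_offsets(positions)
print(offsets)  # [(1.0, 1.0, 1.0), (1.0, 1.0, 0.5)]
```

One plausible benefit of modeling offsets is that the same kind of motion looks alike wherever on the court it happens.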

Each Gaussian is defined by its mean and covariance matrix. At every time step, the LSTM outputs the parameters that define these Gaussians. Every Gaussian requires seven parameters:
- mean and standard deviation in x (\(\mu_x, \sigma_x\))
- mean and standard deviation in y (\(\mu_y, \sigma_y\))
- mean and standard deviation in z (\(\mu_z, \sigma_z\))
- correlation between x and y (\(\rho_{xy}\))
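We used three mixture components, so the LSTM needed to output 7 * 3 = 21 parameters at every time step. These parameters appear in the RNN architecture shown above and in the equation below, which gives the density of an offset \(v = (v_x, v_y, v_z)\) under one Gaussian component, with \(Z\) a normalizing constant.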
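As a sketch of how 21 raw LSTM outputs could be turned into valid parameters, the snippet below follows the convention from Graves's handwriting paper: exponentiate the raw standard-deviation outputs so they stay positive, and pass the correlation through tanh so it stays between -1 and 1. The output ordering, `NUM_MIXTURES`, `PARAM_NAMES`, and `split_mdn_outputs` are hypothetical, not taken from our repo.

```python
import math
from typing import Dict, List

NUM_MIXTURES = 3
# Hypothetical output layout: seven consecutive blocks of NUM_MIXTURES raw values
PARAM_NAMES = ["mu_x", "mu_y", "mu_z", "sigma_x", "sigma_y", "sigma_z", "rho_xy"]


def split_mdn_outputs(raw: List[float]) -> Dict[str, List[float]]:
    """Map 21 unconstrained network outputs to valid Gaussian parameters."""
    expected = NUM_MIXTURES * len(PARAM_NAMES)
    if len(raw) != expected:
        raise ValueError(f"expected {expected} outputs, got {len(raw)}")
    params: Dict[str, List[float]] = {}
    for i, name in enumerate(PARAM_NAMES):
        block = raw[i * NUM_MIXTURES:(i + 1) * NUM_MIXTURES]
        if name.startswith("sigma"):
            # Standard deviations must be positive
            params[name] = [math.exp(v) for v in block]
        elif name == "rho_xy":
            # tanh maps any real number into the valid correlation range
            params[name] = [math.tanh(v) for v in block]
        else:
            # Means are left unconstrained
            params[name] = list(block)
    return params


params = split_mdn_outputs([0.0] * 21)
print(params["sigma_x"], params["rho_xy"])  # [1.0, 1.0, 1.0] [0.0, 0.0, 0.0]
```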

\[p(v \mid x) = \frac{1}{Z}\exp\left(-\frac{1}{2} \begin{pmatrix} v_x - \mu_x \\ v_y - \mu_y \\ v_z - \mu_z \end{pmatrix}^{\top} \begin{pmatrix} \sigma^2_x & \sigma_x \sigma_y \rho_{xy} & 0 \\ \sigma_x \sigma_y \rho_{xy} & \sigma^2_y & 0 \\ 0 & 0 & \sigma^2_z \end{pmatrix}^{-1} \begin{pmatrix} v_x - \mu_x \\ v_y - \mu_y \\ v_z - \mu_z \end{pmatrix} \right)\]
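The density translates directly into code. Below is a minimal plain-Python sketch (illustrative only, not our TensorFlow implementation) that evaluates one component's density and the negative log-likelihood of an offset under the whole mixture. Because the covariance matrix is block-diagonal, the density factors into a correlated bivariate normal in x and y times an independent normal in z. Minimizing this negative log-likelihood is the usual MDN training loss, as in Graves's handwriting model. A mixture also needs one non-negative weight per component, summing to 1; here the weights are an assumed extra input.

```python
import math


def gaussian_pdf(v, mu, sigma, rho_xy):
    """Density of an offset v = (vx, vy, vz) under one component with the
    covariance matrix above: x and y correlated by rho_xy, z independent."""
    dx = (v[0] - mu[0]) / sigma[0]
    dy = (v[1] - mu[1]) / sigma[1]
    dz = (v[2] - mu[2]) / sigma[2]
    one_minus_r2 = 1.0 - rho_xy ** 2
    # Bivariate normal density in the xy plane
    quad_xy = (dx * dx + dy * dy - 2.0 * rho_xy * dx * dy) / one_minus_r2
    norm_xy = 2.0 * math.pi * sigma[0] * sigma[1] * math.sqrt(one_minus_r2)
    pdf_xy = math.exp(-0.5 * quad_xy) / norm_xy
    # Independent univariate normal density in z
    pdf_z = math.exp(-0.5 * dz * dz) / (math.sqrt(2.0 * math.pi) * sigma[2])
    return pdf_xy * pdf_z


def mixture_nll(v, weights, mus, sigmas, rhos):
    """Negative log-likelihood of v under the mixture (weights assumed to sum to 1)."""
    density = sum(
        w * gaussian_pdf(v, m, s, r)
        for w, m, s, r in zip(weights, mus, sigmas, rhos)
    )
    return -math.log(density)
```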

RNN Predictive Model Performance

Our RNN model predicts whether a basketball shot will be a make or a miss. We actually developed twelve models, one for each distance of the ball from the basket from 2 to 13 feet. For example, the first model makes a prediction using trajectory data up until the ball is 2 feet from the basket. This model naturally performs much better than the 13-foot RNN model, which must make predictions without any information from closer than 13 feet to the basket.
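Here is a minimal sketch of how the per-distance training sets could be built. Several details are assumptions rather than facts from our code: the hoop coordinates (`HOOP`), the use of straight-line 3D distance, and the hypothetical helper `truncate_at_distance`.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float, float]

# Assumed hoop center in court coordinates (feet): x = 88.75 (5.25 ft from the
# baseline of a 94 ft court), y = 25 (mid-width), z = 10 (rim height).
HOOP: Point = (88.75, 25.0, 10.0)


def truncate_at_distance(trajectory: List[Point], cutoff_ft: float,
                         hoop: Point = HOOP) -> List[Point]:
    """Keep the points recorded before the ball first comes within
    cutoff_ft of the hoop (straight-line 3D distance, an assumption)."""
    kept: List[Point] = []
    for p in trajectory:
        if math.dist(p, hoop) < cutoff_ft:
            break
        kept.append(p)
    return kept


# One training set per model, for cutoffs of 2, 3, ..., 13 feet
CUTOFFS = list(range(2, 14))

# Toy points 20, 10, 2 and 0 feet from the assumed hoop
shot = [(68.75, 25.0, 10.0), (78.75, 25.0, 10.0), (86.75, 25.0, 10.0), (88.75, 25.0, 10.0)]
print({c: len(truncate_at_distance(shot, c)) for c in (2, 13)})  # {2: 3, 13: 1}
```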

To assess the performance of the models, we use the area under the ROC curve (AUC), a typical metric for binary classification problems. As the chart below shows, the RNN performed much better than the baseline generalized linear model (GLM) and gradient boosted machine (GBM). The GLM and GBM relied on engineered features (including position, speed, distance, and angle to the basket). For example, at 9 feet, the RNN model achieved an AUC of 0.83 in predicting a make or miss, compared with 0.55 for the GLM and 0.694 for the GBM.
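For readers unfamiliar with AUC, this small plain-Python sketch shows what the number means: the probability that a randomly chosen make gets a higher predicted score than a randomly chosen miss, with ties counting half. The pairwise loop is quadratic in the number of shots and is only for illustration; a library routine such as scikit-learn's `roc_auc_score` is much faster in practice.

```python
from typing import Sequence


def auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as the probability that a random make (label 1) scores higher
    than a random miss (label 0); ties count as half a win."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one make and one miss")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(auc([1, 0, 1, 0], [0.9, 0.3, 0.4, 0.6]))  # 0.75
```

An AUC of 0.5 corresponds to random guessing, which puts the GLM's 0.55 in perspective.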

The results show that RNNs are capable of learning the non-linear dynamics in basketball tracking data. They do this using only the movement data, without any need for feature engineering! We believe this will be enormously useful when applied to more complex basketball movements, such as identifying screens.

Generating Trajectories

The second question is whether we can sample new trajectories from the MDN. The answer, again, is yes.

Our model outputs a distribution over the offset at the next time step. If we sample from this distribution and feed each sample back into the model as the next input, we can generate a full trajectory. We also use a bias and a priming sequence to improve the quality of the generated trajectories. The plot below gives a feel for the impact of using a bias and a priming sequence.
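A single sampling step can be sketched in plain Python as follows. This is an illustrative sketch, not our TensorFlow code: it picks a mixture component according to its weight (an input we assume the network also provides), then draws a correlated x-y offset and an independent z offset. The parameter values in the example are made up.

```python
import math
import random


def sample_offset(weights, mus, sigmas, rhos, rng):
    """Draw one (dx, dy, dz) offset from a Gaussian mixture whose components
    have x-y correlation rho and an independent z axis."""
    # Pick a mixture component in proportion to its weight
    k = rng.choices(range(len(weights)), weights=weights)[0]
    mu_x, mu_y, mu_z = mus[k]
    s_x, s_y, s_z = sigmas[k]
    rho = rhos[k]
    z1, z2, z3 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    dx = mu_x + s_x * z1
    # Mixing z1 into dy gives dx and dy correlation rho
    dy = mu_y + s_y * (rho * z1 + math.sqrt(1.0 - rho ** 2) * z2)
    dz = mu_z + s_z * z3
    return dx, dy, dz


# Made-up parameters for a two-component mixture
rng = random.Random(0)
weights = [0.2, 0.8]
mus = [(0.0, 0.0, -0.1), (0.5, 0.5, -0.3)]
sigmas = [(0.1, 0.1, 0.05), (0.1, 0.1, 0.05)]
rhos = [0.0, 0.6]
offset = sample_offset(weights, mus, sigmas, rhos, rng)
```

In a full generator, each sampled offset would be fed back into the LSTM, which then outputs the parameters for the following step.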

The bias uses a simple heuristic to push the samples toward the mode of the distribution. For example, in the first image, probable offsets in the z-direction range from -0.1 to -0.6 feet. An unlucky draw at, say, -0.6 feet will affect the rest of the trajectory, as shown above. A biased sample lies closer to the mode of the distribution: instead of drawing from the Gaussian between -0.1 and -0.6, it draws from a narrower, biased distribution between -0.3 and -0.4.
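As one concrete possibility (an assumption, not necessarily the exact heuristic we used), the sketch below implements the sampling bias from Graves's handwriting paper: every standard deviation is shrunk by a factor of exp(-b), and the mixture weights are sharpened toward the most likely component. With b = 0 the distribution is unchanged; larger b concentrates samples near the mode, which is the kind of narrowing from -0.1..-0.6 toward -0.3..-0.4 described above. The name `apply_bias` is hypothetical.

```python
import math


def apply_bias(weights, sigmas, bias):
    """Graves-style sampling bias (assumed heuristic): shrink each standard
    deviation by exp(-bias) and sharpen the (strictly positive) mixture weights.
    bias = 0 leaves the distribution unchanged."""
    new_sigmas = [tuple(s * math.exp(-bias) for s in comp) for comp in sigmas]
    # Equivalent to softmax(logits * (1 + bias)) when weights = softmax(logits)
    logits = [math.log(w) * (1.0 + bias) for w in weights]
    top = max(logits)
    exps = [math.exp(l - top) for l in logits]  # subtract max for stability
    total = sum(exps)
    new_weights = [e / total for e in exps]
    return new_weights, new_sigmas


w1, s1 = apply_bias([0.2, 0.3, 0.5], [(1.0, 2.0, 0.5)] * 3, bias=1.0)
```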

To see the effect of the bias, select a priming sequence of 5 and then slide the bias in the figure below. As the bias increases, the samples become less diverse and the trajectory becomes smoother.

A priming sequence feeds the first several points of a real trajectory into the model as the initial steps. These steps prime, or warm up, the LSTM, resulting in more realistic trajectories. In the figure below, select a bias of 0 and then slide the priming sequence. As the priming sequence gets longer, the sample becomes less noisy and smoother.
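The generation loop with a priming sequence can be sketched as below. The `step` and `sample` callables are hypothetical stand-ins for one LSTM-plus-MDN step and for drawing an offset from its output; the toy versions here just keep a running average of offsets, so this shows the control flow, not our model.

```python
import random
from typing import Callable, List, Tuple

Offset = Tuple[float, float, float]


def generate_with_priming(step: Callable, sample: Callable, init_state,
                          true_offsets: List[Offset], n_prime: int,
                          n_total: int, rng: random.Random) -> List[Offset]:
    """Feed the first n_prime real offsets to the model to warm up its state,
    then sample the remaining steps, feeding each sample back in."""
    if not 1 <= n_prime <= min(n_total, len(true_offsets)):
        raise ValueError("need 1 <= n_prime <= min(n_total, len(true_offsets))")
    state = init_state
    out = list(true_offsets[:n_prime])
    for o in out:  # priming: the model only updates its state on real data
        params, state = step(o, state)
    while len(out) < n_total:  # free-running generation
        o = sample(params, rng)
        out.append(o)
        params, state = step(o, state)
    return out


# Hypothetical toy stand-ins: the "state" is a running average of past
# offsets and doubles as the mean of the next draw.
def toy_step(offset, state):
    avg = tuple(0.8 * s + 0.2 * o for s, o in zip(state, offset))
    return avg, avg


def toy_sample(mean, rng, sigma=0.05):
    return tuple(m + rng.gauss(0.0, sigma) for m in mean)


true_offsets = [(1.0, 0.0, 0.2)] * 10
traj = generate_with_priming(toy_step, toy_sample, (0.0, 0.0, 0.0),
                             true_offsets, n_prime=5, n_total=20, rng=random.Random(0))
```

With this toy model, a longer priming sequence leaves the running average closer to the true offset before sampling starts (after k priming steps its x-component is 1 - 0.8^k), loosely mirroring how priming warms up the LSTM's hidden state.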

Next Steps

This work has shown the value of an RNN architecture for modeling basketball trajectories. The key finding is that the RNN was capable of learning a non-linear system from positional data. For a complete write-up, see our paper or run the code from GitHub. There are several takeaways from this work.

First, we recognize that this is a toy problem, in the sense that basketball trajectories are already well described by physics-based models. Even so, we found basketball trajectories useful in the way MNIST is: a dataset we understood well, against which we could test and verify network architectures.

Second, we showed that RNNs can model the trajectory data. This will serve as a starting point for incorporating player movement into the model. We have already begun using these architectures with player movement added, for both prediction and play generation.

Finally, we hope our work will inspire others to try RNNs. We have released the dataset we used along with the TensorFlow code so that others can start experimenting with RNNs themselves.

Acknowledgements

Graves’ work was deeply inspirational, in particular Alex Graves’ paper on handwriting and his lecture at Oxford.

We also found Hardmaru’s blog post explaining MDNs and Andrej Karpathy’s blog post on RNNs useful.

For a gentle introduction to RNNs, see Rajiv’s posts on teaching an RNN to add and on using TensorFlow with Shiny.

Rob has a series of posts on using deep learning for time series classification.

Our GitHub repo is here.