Customer matching using random forest

Introduction

Measuring a treatment effect on data where the response was recorded without an experimental design usually requires matching to control for confounding effects.

Here I outline matching with random forests, which runs faster than genetic matching while maintaining reasonable match quality.

Why?

In most cases the first line of attack is propensity score matching. This is easily done with a logistic-regression GLM, using a package like MatchIt in R.
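As a point of comparison, here is a minimal base-R sketch of what propensity score matching does (not MatchIt's implementation): fit a logit GLM for treatment, then match each treated row to the control row with the closest predicted score. The helper ps_match(), the column names treated, x1 and x2, and the simulated data are all hypothetical.

```r
# Hypothetical helper: 1-nearest-neighbour matching on a logit propensity score.
# Matching is with replacement: one control can serve several treated rows.
ps_match <- function(data, treat_col = "treated") {
  # fit P(treated = 1 | covariates) with a logistic-regression GLM
  fml <- stats::as.formula(paste(treat_col, "~ ."))
  fit <- stats::glm(fml, data = data, family = stats::binomial())
  score <- unname(stats::predict(fit, type = "response"))

  treated_idx <- which(data[[treat_col]] == 1)
  control_idx <- which(data[[treat_col]] == 0)

  # for each treated row, pick the control row with the smallest score distance
  matched <- vapply(treated_idx, function(i) {
    control_idx[which.min(abs(score[control_idx] - score[i]))]
  }, integer(1))

  data.frame(treated_row = treated_idx,
             control_row = matched,
             distance = abs(score[treated_idx] - score[matched]))
}

# toy data: the chance of treatment depends on x1, a confounder
set.seed(1)
n <- 200
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$treated <- rbinom(n, 1, plogis(toy$x1))
head(ps_match(toy))
```

Every covariate is collapsed into the single score column before matching, which is exactly the one-dimensional limitation described next.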

Because this matches on a single one-dimensional score, people often turn to genetic matching for its high match quality. Genetic matching is slow, however, and in production it can end up costing too much to run.

How?

Random forest matching is similar to genetic matching in that customers who are matched together generally have similar values across many different variables, not just a similar propensity score.

It does this by training a propensity model as a forest of decision trees. Once the forest has been trained, we look at the terminal leaves of every tree and record how often two samples end up in the same leaf. Each tree splits on different variables and yields its own co-occurrence matrix (1 if two samples share a leaf, 0 otherwise). The average of these co-occurrence matrices across all trees is the proximity matrix used for matching.
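The averaging step can be sketched in base R. Assume each sample's terminal node IDs are stored in a matrix with one row per sample and one column per tree; with ranger, predict(rf_model, data, type = "terminalNodes")$predictions returns such a matrix. The helper leaf_proximity() and the toy node IDs are hypothetical, and this illustrates the idea rather than the Similarity package's implementation.

```r
# Hypothetical helper: proximity between the rows of two terminal-node matrices.
# Entry [i, t] of each matrix is the ID of the leaf that sample i reaches in tree t.
leaf_proximity <- function(nodes_x, nodes_y) {
  stopifnot(ncol(nodes_x) == ncol(nodes_y))
  n_trees <- ncol(nodes_x)
  prox <- matrix(0, nrow = nrow(nodes_x), ncol = nrow(nodes_y))
  for (t in seq_len(n_trees)) {
    # co-occurrence matrix for tree t: TRUE (1) where the two samples share a leaf
    prox <- prox + outer(nodes_x[, t], nodes_y[, t], FUN = "==")
  }
  prox / n_trees  # average over trees = fraction of trees with a shared leaf
}

# toy example: 2 treated and 3 control samples, 4 trees
treated_nodes <- rbind(c(3, 7, 2, 5),
                       c(4, 6, 2, 9))
control_nodes <- rbind(c(3, 7, 1, 5),
                       c(4, 8, 2, 9),
                       c(1, 1, 1, 1))
leaf_proximity(treated_nodes, control_nodes)
# row 1: 0.75 0.25 0   (treated 1 shares 3 of 4 leaves with control 1)
# row 2: 0.00 0.75 0   (treated 2 shares 3 of 4 leaves with control 2)
```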

Comparing run times - Random Forests VS Genetic matching

Implementation of random forest matching

There are many different ways to perform matching using proximity matrices. With the classic randomForest package, you can ask the model to keep its proximity matrix by setting proximity = TRUE.

For this blog I decided to implement the process with the ranger package instead, mainly because ranger is better optimized for large data sets and multi-core processing.

Packages

To run this code you will need the following packages:

# install.packages("ranger")
library(ranger)

# install.packages("ROSE")
library(ROSE)
## Loaded ROSE 0.0-3
# install.packages("devtools")
# devtools::install_github("sipemu/similarity")
library(Similarity)

library(dplyr)

Thin feature space

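For this specific example, our data contains 35 principal components generated by a combination of PCA and MCA models. We fit our model on the top 5 components together with three monthly spend variables; the first two columns are the customer ID and the vitality response: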

# keep the ID, the response, three monthly spend columns and the top 5 components
rf_data <-
  rf_data %>% 
  select(1:10) %>% 
  mutate(vitality = factor(vitality))

# give the columns readable names
rf_data <- 
  rf_data %>%
  setNames(c("customer_no", "vitality",
             "spend_month_2013_10_01", "spend_month_2013_11_01", "spend_month_2013_12_01",
             "Dim_1", "Dim_2", "Dim_3", "Dim_4", "Dim_5"))

rf_data %>% glimpse
## Observations: 456,366
## Variables: 10
## $ customer_no            <chr> "c51ce410c124a10e0db5e4b97fc2af39", "28...
## $ vitality               <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
## $ spend_month_2013_10_01 <dbl> 12281.10, 0.00, 1585.58, 5314.16, 0.00,...
## $ spend_month_2013_11_01 <dbl> 12137.86, 0.00, 0.00, 8136.73, 286.08, ...
## $ spend_month_2013_12_01 <dbl> 12769.57, 541.62, 5609.84, 7454.51, 114...
## $ Dim_1                  <dbl> 2.206990e+00, -9.920552e-01, 3.954618e-...
## $ Dim_2                  <dbl> 0.34242884, -2.20506687, -0.49166710, 0...
## $ Dim_3                  <dbl> 0.34804986, -1.95304588, 0.03740382, -0...
## $ Dim_4                  <dbl> -1.696023886, -1.385098798, 1.309264779...
## $ Dim_5                  <dbl> -0.30175202, -0.47332798, -1.47041980, ...

Fit random forest model

In this example I used the ROSE sampling strategy to balance the classes, because the conversion class (vitality == 1) was heavily under-represented.

This class imbalance skews a GLM or random forest, because the loss (or impurity) function can be minimized almost entirely by predicting the majority class. Predicting a conversion would produce many false positives among the far more numerous non-converting customers, so the model may never predict a conversion at all.
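To see what balancing does, here is a minimal base-R sketch of plain random oversampling. This is a simpler technique than ROSE: it duplicates existing minority-class rows rather than generating smoothed synthetic ones. The helper oversample_minority() and the toy data are hypothetical.

```r
# Hypothetical helper: balance a binary outcome by resampling minority-class rows
# with replacement until both classes have the same number of rows.
# Note: calls set.seed(), so it resets the global random number stream.
oversample_minority <- function(data, outcome, seed = 8020) {
  set.seed(seed)
  counts <- table(data[[outcome]])
  minority <- names(counts)[which.min(counts)]
  minority_rows <- which(as.character(data[[outcome]]) == minority)
  n_extra <- max(counts) - min(counts)
  # sample.int avoids sample()'s special case when only one minority row exists
  extra_rows <- minority_rows[sample.int(length(minority_rows), n_extra, replace = TRUE)]
  data[c(seq_len(nrow(data)), extra_rows), , drop = FALSE]
}

toy_imbalanced <- data.frame(x = rnorm(1000),
                             vitality = factor(c(rep(0, 950), rep(1, 50))))
table(toy_imbalanced$vitality)                                    # 950 vs 50
table(oversample_minority(toy_imbalanced, "vitality")$vitality)   # 950 vs 950
```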

Don’t simply copy and paste this methodology! ROSE creates synthetic samples, so understand what you are matching and how you are sampling.

Okay here we go:

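# balance the classes for training only; ROSE returns synthetic rows,
# so rf_data (the real customers) is kept for the matching step
rose_data <- ROSE(vitality ~ ., data = rf_data %>% select(-customer_no), seed = 8020)$data
rf_model <- ranger(formula = vitality ~ ., data = rose_data, write.forest = TRUE, classification = TRUE)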
## Growing trees.. Progress: 7%. Estimated remaining time: 12 minutes, 15 seconds.
## Growing trees.. Progress: 17%. Estimated remaining time: 6 minutes, 39 seconds.
## Growing trees.. Progress: 26%. Estimated remaining time: 5 minutes, 34 seconds.
## Growing trees.. Progress: 32%. Estimated remaining time: 5 minutes, 9 seconds.
## Growing trees.. Progress: 40%. Estimated remaining time: 4 minutes, 33 seconds.
## Growing trees.. Progress: 48%. Estimated remaining time: 3 minutes, 49 seconds.
## Growing trees.. Progress: 56%. Estimated remaining time: 3 minutes, 10 seconds.
## Growing trees.. Progress: 64%. Estimated remaining time: 2 minutes, 32 seconds.
## Growing trees.. Progress: 72%. Estimated remaining time: 1 minute, 58 seconds.
## Growing trees.. Progress: 80%. Estimated remaining time: 1 minute, 24 seconds.
## Growing trees.. Progress: 88%. Estimated remaining time: 48 seconds.
## Growing trees.. Progress: 97%. Estimated remaining time: 11 seconds.

Calculate proximity matrices

In this case I match one-to-one. This is a deliberate choice; in some situations one-to-many matching makes more sense.

Make sure you record the weight of each match so that you can weight your aggregations of the response when comparing the two groups; that part is beyond the scope of this blog.

I take the first 1,000 treated and 1,000 control rows to speed up the runtime:

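# first 1,000 treated and control rows, to keep the runtime short
treatment_data <- rf_data %>% filter(vitality == 1) %>% slice(1:1000)
control_data <- rf_data %>% filter(vitality == 0) %>% slice(1:1000)

# keep only the predictor columns the forest was trained on
prox_matrix <- Similarity::proximityMatrixRanger(
  x = treatment_data %>% select(starts_with("spend_month"), starts_with("Dim_")),
  y = control_data %>% select(starts_with("spend_month"), starts_with("Dim_")),
  rf = rf_model)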

Here each row represents a treated customer and each column represents a control customer.

A one-row example (treated customer 816):

# positions and values of the non-zero proximities for treated customer 816
which(prox_matrix[816,] > 0)
## [1]    1   37  240  297  391  452  502  770 1000
prox_matrix[816,][prox_matrix[816,] > 0]
## [1] 0.002 0.002 0.002 0.002 0.002 0.002 0.002 0.006 0.002

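Each value is the proximity score of a (treated, control) pair: the fraction of trees in which the two customers end up in the same terminal node. Since ranger grows 500 trees by default, a value of 0.002 (1/500) most likely means the pair shared a leaf in just one tree.

In the simple one-to-one case, the argmax of each row is the column index of the matched control customer for that treated customer. This matches with replacement: the same control customer can be the best match for several treated customers.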

Get matches

Getting the matches is simple, since you can just grab the IDs for each response:

# column index of the best-matching control customer for each treated customer
match_id <- prox_matrix %>% apply(MARGIN = 1, FUN = which.max)

# map row/column positions back to customer IDs
# (assumes rf_data rows follow the same order as MFAData)
matches_out <- 
  tibble(
    customer_no = MFAData %>% 
      filter(vitality == 1) %>% 
      pull(customer_no) %>% .[1:1000], 
    match = MFAData %>% 
      filter(vitality == 0) %>% 
      pull(customer_no) %>% .[1:1000] %>% .[match_id],
    weights = 1
  )

matches_out
## # A tibble: 1,000 x 3
##    customer_no                      match                          weights
##    <chr>                            <chr>                            <dbl>
##  1 1bc0249a6412ef49b07fe6f62e6dc8de 12b5fc7904a4ed8e390a03143b2cc…    1.00
##  2 4e87337f366f72daa424dae11df0538c 22785dd2577be2ce28ef79febe80d…    1.00
##  3 f0b1d5879866f2c2eba77f39993d1184 ca47d49b363ef094ba80bd0b8c4e0…    1.00
##  4 20546457187cf3d52ea86538403e47cc 671792587502028b6cd4be7c4d662…    1.00
##  5 cfc5d9422f0c8f8ad796711102dbe32b 1b43914f3e2e1f22b1090eb86d69f…    1.00
##  6 516341c3e8f4543c8d465b0c514a6f92 7d4ce676f04524dadbb2f1565c871…    1.00
##  7 100d5d9191f185eeb98d6e291756954a 0a4dc6dae338c9cb08947c07581f7…    1.00
##  8 7364e0bb7f15ebfbc9e12d5b13f51a02 da0dba87d95286d836e37ca60ab1e…    1.00
##  9 764f9642ebf04622c53ebc366a68c0a7 fc8ce6292e51ac8214f544324c56d…    1.00
## 10 0634ac7160d3fd64ddf19bdcbdd401bb 2cb2cdcca9bdd55a897d897ac67f7…    1.00
## # ... with 990 more rows

Here I created a tibble that pairs each treated customer ID with the ID of its matched control customer.

Take special care with rows that contain only zeros: which.max() returns 1 for such a row, silently matching that treated customer to the first control customer. Make a conscious decision about these cases, for example to assign a match at random or to discard the customer.
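One way to make that decision explicit is to return NA for rows with no positive proximity, so unmatched treated customers can be discarded or assigned later on purpose. The helper best_match() and the toy matrix are hypothetical.

```r
# Hypothetical helper: argmax match per treated row, NA when the row is all zeros
# (plain which.max() would silently return 1 for such rows).
best_match <- function(prox) {
  prox <- as.matrix(prox)
  apply(prox, 1, function(row) if (max(row) > 0) which.max(row) else NA_integer_)
}

prox_toy <- rbind(c(0.002, 0.006, 0),
                  c(0,     0,     0),
                  c(0.004, 0,     0.002))
best_match(prox_toy)  # 2, NA, 1
```

Rows that come back as NA can then be dropped with !is.na() or handled by whichever rule you chose.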

I assign every match a weight of 1 because we matched one-to-one. If you match one-to-many, make sure to assign the appropriate weights (see the documentation of the Similarity::proximityMatrixRanger function).
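For one-to-many matching, one common convention (an assumption here, not taken from the Similarity documentation) is to keep the k most similar controls for each treated customer and give each a weight of 1/k, so every treated customer contributes a total weight of 1. The helper top_k_matches() and the toy matrix are hypothetical.

```r
# Hypothetical helper: k most similar controls (by proximity) per treated row,
# each weighted 1/k. Ties and zero-proximity controls are not handled specially.
top_k_matches <- function(prox, k = 3) {
  prox <- as.matrix(prox)
  k <- min(k, ncol(prox))
  do.call(rbind, lapply(seq_len(nrow(prox)), function(i) {
    best <- order(prox[i, ], decreasing = TRUE)[seq_len(k)]
    data.frame(treated_row = i, control_row = best, weights = 1 / k)
  }))
}

prox_toy <- rbind(c(0.002, 0.006, 0.000, 0.004),
                  c(0.000, 0.002, 0.008, 0.000))
top_k_matches(prox_toy, k = 2)
# treated row 1 -> controls 2 and 4; treated row 2 -> controls 3 and 2; weight 0.5 each
```

The same all-zero-row caveat applies here: a row with fewer than k positive proximities still receives k matches.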