Cohort Retention Analysis


In this project, we conduct a time-based cohort and retention analysis in Python to examine how many customers in a given cohort stay and how many leave over time. We group customers into cohorts based on their time of acquisition and compute the cohort index, which is the number of months since the customer was acquired. We then compute the retention rate of each cohort over time.

Project Summary


Context


This is a transnational data set that contains all the transactions occurring between 1 December 2010 and 9 December 2011 for a UK-based online retail company. The company mainly sells unique all-occasion gifts. Many of its customers are wholesalers.

The company wants to understand the behavior of their customers over time. They want to know how many of their customers remain active after acquisition. Hence we will run a cohort and retention analysis.

A cohort is a group of people who share a common characteristic and are followed over a period of time. Examples include users who became customers in the same month, or a graduating class of students. A cohort analysis can give a company and its managers a great deal of information about the behavior of customers who were acquired and/or converted during a campaign. Companies use cohort analysis to understand customer trends and patterns over time and to tailor their product and service offers to the identified cohorts.

Performing a cohort retention analysis is a useful practice for spotting, and ultimately reducing, early customer churn.

Actions


Before running the analysis, we first explored the data. It contains several columns, including the invoice number, stock code, description, quantity, invoice date, unit price, customer ID and country. For this analysis, we are interested in InvoiceDate and CustomerID. We removed duplicated records, dropped all records with a missing customer ID, and used the invoice number to filter out all cancelled transactions.
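The three cleaning rules can be condensed into a single helper. This is a minimal sketch, not the notebook's own code (which follows step by step in the Data Cleaning section): `clean_transactions` is a hypothetical name, the toy rows are made up, and it assumes the Online Retail column names and that an InvoiceNo starting with "C" marks a cancellation.

```python
import pandas as pd


def clean_transactions(df: pd.DataFrame) -> pd.DataFrame:
    """Apply the three cleaning rules described above (illustrative sketch)."""
    out = df.drop_duplicates()                  # 1. exact duplicate rows
    out = out.dropna(subset=["CustomerID"])     # 2. rows without a customer
    # 3. cancellations: InvoiceNo starts with "C"; astype(str) handles mixed int/str values
    is_cancelled = out["InvoiceNo"].astype(str).str.startswith("C")
    return out[~is_cancelled]


# Hypothetical toy rows using the Online Retail column names
toy = pd.DataFrame({
    "InvoiceNo": [536365, 536365, "C536379", 536366, 536367],
    "InvoiceDate": pd.to_datetime(["2010-12-01", "2010-12-01", "2010-12-01",
                                   "2010-12-01", "2010-12-02"]),
    "CustomerID": [17850.0, 17850.0, 14527.0, None, 13047.0],
})
clean = clean_transactions(toy)
print(clean)
```

Row 1 is dropped as a duplicate of row 0, row 2 as a cancellation and row 3 for its missing CustomerID, leaving rows 0 and 4.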

We performed a cohort analysis by grouping customers into cohorts based on their acquisition time, i.e. the month of their first purchase. We then computed the cohort index, which represents the number of months since acquisition, and created a cohort table.

We performed the retention analysis by computing the retention rate of each cohort: we divided the values in the cohort table by the corresponding cohort size and converted the results to percentages. We used a heatmap to visualize the result of the retention analysis.

Result


Figure: Retention table heatmap

To interpret the result: there are 13 cohorts (December 2010 to December 2011) and 13 cohort indexes (0 to 12). Retention rates are shown as percentages; on the color scale, low values are bluish, values around 50% are grayish and values close to 100% are reddish. For instance, the retention rate of the December 2010 cohort at cohort index 11 is 50%. This means that 50% of the customers acquired in December 2010 made at least one purchase 11 months later, in November 2011. In other words, half of the customers acquired in December 2010 were still active 11 months later.


Importing Required Dependencies

In [ ]:
# import libraries 
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns 
import datetime as dt

Data Loading

In [ ]:
# loading the data to a data frame
data = pd.read_excel("./data/Online Retail.xlsx")

Data Definition

  • InvoiceNo: Invoice number, a 6-digit integral number uniquely assigned to each transaction. If this code starts with the letter 'C', it indicates a cancellation.
  • StockCode: Product (item) code, a 5-digit integral number uniquely assigned to each distinct product. Some codes carry a letter suffix (e.g. 85123A) or are special codes such as D (Discount) and M (Manual).
  • Description: Product (item) name.
  • Quantity: The quantities of each product (item) per transaction.
  • InvoiceDate: Invoice Date and time, the day and time when each transaction was generated.
  • UnitPrice: Unit price, the product price per unit in pounds sterling.
  • CustomerID: Customer number, a 5-digit integral number uniquely assigned to each customer.
  • Country: Country name, the name of the country where each customer resides.
In [ ]:
# print the first 5 rows of the dataframe
data.head()
Out[ ]:
InvoiceNo StockCode Description Quantity InvoiceDate UnitPrice CustomerID Country
0 536365 85123A WHITE HANGING HEART T-LIGHT HOLDER 6 2010-12-01 08:26:00 2.55 17850.0 United Kingdom
1 536365 71053 WHITE METAL LANTERN 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom
2 536365 84406B CREAM CUPID HEARTS COAT HANGER 8 2010-12-01 08:26:00 2.75 17850.0 United Kingdom
3 536365 84029G KNITTED UNION FLAG HOT WATER BOTTLE 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom
4 536365 84029E RED WOOLLY HOTTIE WHITE HEART. 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom
In [ ]:
# print last 5 rows of the dataframe
data.tail()
Out[ ]:
InvoiceNo StockCode Description Quantity InvoiceDate UnitPrice CustomerID Country
541904 581587 22613 PACK OF 20 SPACEBOY NAPKINS 12 2011-12-09 12:50:00 0.85 12680.0 France
541905 581587 22899 CHILDREN'S APRON DOLLY GIRL 6 2011-12-09 12:50:00 2.10 12680.0 France
541906 581587 23254 CHILDRENS CUTLERY DOLLY GIRL 4 2011-12-09 12:50:00 4.15 12680.0 France
541907 581587 23255 CHILDRENS CUTLERY CIRCUS PARADE 4 2011-12-09 12:50:00 4.15 12680.0 France
541908 581587 22138 BAKING SET 9 PIECE RETROSPOT 3 2011-12-09 12:50:00 4.95 12680.0 France
In [ ]:
# number of rows and columns in the dataset
data.shape
Out[ ]:
(541909, 8)
In [ ]:
# getting some information about the data
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 541909 entries, 0 to 541908
Data columns (total 8 columns):
 #   Column       Non-Null Count   Dtype         
---  ------       --------------   -----         
 0   InvoiceNo    541909 non-null  object        
 1   StockCode    541909 non-null  object        
 2   Description  540455 non-null  object        
 3   Quantity     541909 non-null  int64         
 4   InvoiceDate  541909 non-null  datetime64[ns]
 5   UnitPrice    541909 non-null  float64       
 6   CustomerID   406829 non-null  float64       
 7   Country      541909 non-null  object        
dtypes: datetime64[ns](1), float64(2), int64(1), object(4)
memory usage: 33.1+ MB

The columns have suitable data types (InvoiceDate is already parsed as datetime64), but CustomerID (406829 of 541909 values non-null) and Description (540455 non-null) contain missing values.
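From the info() output, 541909 - 406829 = 135080 rows (about 24.9%) have no CustomerID, while only 1454 rows (about 0.27%) lack a Description. A small helper can report these counts and shares directly. This is a sketch: `missing_summary` is a hypothetical name, and the toy frame stands in for `data`.

```python
import pandas as pd


def missing_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Count and percentage of missing values per column, largest count first."""
    counts = df.isnull().sum()
    summary = pd.DataFrame({
        "missing": counts,
        "percent": (counts / len(df) * 100).round(2),
    })
    return summary.sort_values("missing", ascending=False)


# Hypothetical toy frame; on the real data this would be missing_summary(data)
toy = pd.DataFrame({
    "Description": ["WHITE METAL LANTERN", None, "WICKER STAR", "BAKING SET"],
    "CustomerID": [17850.0, None, None, 12680.0],
})
summary = missing_summary(toy)
print(summary)
```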

Data Cleaning

Let's check whether there are any duplicated records.

In [ ]:
data.duplicated().sum()
Out[ ]:
5268
In [ ]:
duplicated_records = data[data.duplicated()]
duplicated_records
Out[ ]:
InvoiceNo StockCode Description Quantity InvoiceDate UnitPrice CustomerID Country
517 536409 21866 UNION JACK FLAG LUGGAGE TAG 1 2010-12-01 11:45:00 1.25 17908.0 United Kingdom
527 536409 22866 HAND WARMER SCOTTY DOG DESIGN 1 2010-12-01 11:45:00 2.10 17908.0 United Kingdom
537 536409 22900 SET 2 TEA TOWELS I LOVE LONDON 1 2010-12-01 11:45:00 2.95 17908.0 United Kingdom
539 536409 22111 SCOTTIE DOG HOT WATER BOTTLE 1 2010-12-01 11:45:00 4.95 17908.0 United Kingdom
555 536412 22327 ROUND SNACK BOXES SET OF 4 SKULLS 1 2010-12-01 11:49:00 2.95 17920.0 United Kingdom
... ... ... ... ... ... ... ... ...
541675 581538 22068 BLACK PIRATE TREASURE CHEST 1 2011-12-09 11:34:00 0.39 14446.0 United Kingdom
541689 581538 23318 BOX OF 6 MINI VINTAGE CRACKERS 1 2011-12-09 11:34:00 2.49 14446.0 United Kingdom
541692 581538 22992 REVOLVER WOODEN RULER 1 2011-12-09 11:34:00 1.95 14446.0 United Kingdom
541699 581538 22694 WICKER STAR 1 2011-12-09 11:34:00 2.10 14446.0 United Kingdom
541701 581538 23343 JUMBO BAG VINTAGE CHRISTMAS 1 2011-12-09 11:34:00 2.08 14446.0 United Kingdom

5268 rows × 8 columns

In [ ]:
# Drop duplicated transactions
data = data.drop(duplicated_records.index, axis=0)
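Dropping by index label works here because `read_excel` returns a unique RangeIndex. pandas also provides `DataFrame.drop_duplicates`, which gives the same result in one call and does not depend on the index labels being unique. A minimal check on hypothetical toy rows:

```python
import pandas as pd

# Hypothetical toy frame with one exact duplicate row (index 1 repeats index 0)
toy = pd.DataFrame({
    "InvoiceNo": [536409, 536409, 536412],
    "StockCode": ["21866", "21866", "22327"],
    "Quantity": [1, 1, 1],
})

# Approach used in the notebook: locate the duplicates, then drop them by index label
via_index = toy.drop(toy[toy.duplicated()].index, axis=0)

# Equivalent one-liner; keep="first" (the default) retains the first occurrence
via_method = toy.drop_duplicates(keep="first")

print(via_index.equals(via_method))  # True
```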

Let's check for missing values

In [ ]:
# checking for missing values
data.isnull().sum()
Out[ ]:
InvoiceNo           0
StockCode           0
Description      1454
Quantity            0
InvoiceDate         0
UnitPrice           0
CustomerID     135037
Country             0
dtype: int64
In [ ]:
# drop rows with no customer ID
data = data.dropna(subset=['CustomerID'])
data.isnull().sum()
Out[ ]:
InvoiceNo      0
StockCode      0
Description    0
Quantity       0
InvoiceDate    0
UnitPrice      0
CustomerID     0
Country        0
dtype: int64

Let's check for cancelled orders

In [ ]:
cancelled_records = data[data["InvoiceNo"].str.startswith("C").fillna(False)]
cancelled_records
Out[ ]:
InvoiceNo StockCode Description Quantity InvoiceDate UnitPrice CustomerID Country
141 C536379 D Discount -1 2010-12-01 09:41:00 27.50 14527.0 United Kingdom
154 C536383 35004C SET OF 3 COLOURED FLYING DUCKS -1 2010-12-01 09:49:00 4.65 15311.0 United Kingdom
235 C536391 22556 PLASTERS IN TIN CIRCUS PARADE -12 2010-12-01 10:24:00 1.65 17548.0 United Kingdom
236 C536391 21984 PACK OF 12 PINK PAISLEY TISSUES -24 2010-12-01 10:24:00 0.29 17548.0 United Kingdom
237 C536391 21983 PACK OF 12 BLUE PAISLEY TISSUES -24 2010-12-01 10:24:00 0.29 17548.0 United Kingdom
... ... ... ... ... ... ... ... ...
540449 C581490 23144 ZINC T-LIGHT HOLDER STARS SMALL -11 2011-12-09 09:57:00 0.83 14397.0 United Kingdom
541541 C581499 M Manual -1 2011-12-09 10:28:00 224.69 15498.0 United Kingdom
541715 C581568 21258 VICTORIAN SEWING BOX LARGE -5 2011-12-09 11:57:00 10.95 15311.0 United Kingdom
541716 C581569 84978 HANGING HEART JAR T-LIGHT HOLDER -1 2011-12-09 11:58:00 1.25 17315.0 United Kingdom
541717 C581569 20979 36 PENCILS TUBE RED RETROSPOT -5 2011-12-09 11:58:00 1.25 17315.0 United Kingdom

8872 rows × 8 columns
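Why `.fillna(False)` in the cell above? InvoiceNo is an object column that mixes integers (e.g. 536365) with strings (e.g. "C536379"), and pandas string methods return NaN for the non-string entries. The sketch below reproduces this on a hypothetical four-value Series and shows an alternative, casting to `str` first, which yields a plain boolean mask.

```python
import pandas as pd

# Hypothetical mixed-type column, mirroring how InvoiceNo loads (ints plus "C..." strings)
invoice_no = pd.Series([536365, "C536379", 536366, "C536383"], dtype=object)

# .str methods return NaN for the non-string (int) entries
raw = invoice_no.str.startswith("C")
print(raw.isna().tolist())   # [True, False, True, False]

# Casting to str first gives a clean boolean mask without fillna
mask = invoice_no.astype(str).str.startswith("C")
print(mask.tolist())         # [False, True, False, True]
```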

In [ ]:
# Exclude cancelled orders from the data
data = data.drop(cancelled_records.index, axis=0)

We found 5268 duplicated records, 135037 records with a missing customer ID (counted after de-duplication) and 8872 cancelled transaction lines. To obtain an accurate result from the cohort retention analysis, we removed all duplicates, all records without a customer ID and all cancelled transactions.

Cohort Analysis

Calculate the acquisition month (Cohort Month) of each customer as their earliest invoice month and assign it to all of that customer's transactions. Then compute the cohort index and build the cohort table.

In [ ]:
# create an invoice month
data['Invoice_Month'] = data['InvoiceDate'].dt.to_period('M')
data.head()                                                 
Out[ ]:
InvoiceNo StockCode Description Quantity InvoiceDate UnitPrice CustomerID Country Invoice_Month
0 536365 85123A WHITE HANGING HEART T-LIGHT HOLDER 6 2010-12-01 08:26:00 2.55 17850.0 United Kingdom 2010-12
1 536365 71053 WHITE METAL LANTERN 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom 2010-12
2 536365 84406B CREAM CUPID HEARTS COAT HANGER 8 2010-12-01 08:26:00 2.75 17850.0 United Kingdom 2010-12
3 536365 84029G KNITTED UNION FLAG HOT WATER BOTTLE 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom 2010-12
4 536365 84029E RED WOOLLY HOTTIE WHITE HEART. 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom 2010-12
In [ ]:
# define the acquisition month (Cohort Month) as each customer's earliest invoice month,
# broadcast to every transaction of that customer
data['Cohort_Month'] = data.groupby('CustomerID')['Invoice_Month'].transform('min')
data.tail()
Out[ ]:
InvoiceNo StockCode Description Quantity InvoiceDate UnitPrice CustomerID Country Invoice_Month Cohort_Month
541904 581587 22613 PACK OF 20 SPACEBOY NAPKINS 12 2011-12-09 12:50:00 0.85 12680.0 France 2011-12 2011-08
541905 581587 22899 CHILDREN'S APRON DOLLY GIRL 6 2011-12-09 12:50:00 2.10 12680.0 France 2011-12 2011-08
541906 581587 23254 CHILDRENS CUTLERY DOLLY GIRL 4 2011-12-09 12:50:00 4.15 12680.0 France 2011-12 2011-08
541907 581587 23255 CHILDRENS CUTLERY CIRCUS PARADE 4 2011-12-09 12:50:00 4.15 12680.0 France 2011-12 2011-08
541908 581587 22138 BAKING SET 9 PIECE RETROSPOT 3 2011-12-09 12:50:00 4.95 12680.0 France 2011-12 2011-08
In [ ]:
# create a function to get the month and year
def get_month_year(df, column):
    month = df[column].dt.month
    year = df[column].dt.year
    return month, year 
In [ ]:
# get the month and year for our cohort and invoice columns
Invoice_month, Invoice_year = get_month_year(data, 'Invoice_Month')
Cohort_month, Cohort_year = get_month_year(data, 'Cohort_Month')

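The cohort index represents the number of months between a transaction's invoice month and the customer's acquisition (cohort) month; index 0 is the acquisition month itself.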
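As a formula, with y and m denoting the year and month of the invoice month (subscript i) and of the cohort month (subscript c), the next cell computes:

```latex
\text{Cohort\_Index} = 12\,(y_i - y_c) + (m_i - m_c)

% Example: cohort month December 2010, invoice month November 2011
12\,(2011 - 2010) + (11 - 12) = 12 - 1 = 11
```

The year term converts whole years into months, and the month term can be negative when the invoice month falls earlier in the calendar year than the cohort month, as in the example.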

In [ ]:
# create a cohort index for each transaction
year_diff = Invoice_year - Cohort_year
month_diff = Invoice_month - Cohort_month
data['Cohort_Index'] = year_diff*12+month_diff
data.head()
Out[ ]:
InvoiceNo StockCode Description Quantity InvoiceDate UnitPrice CustomerID Country Invoice_Month Cohort_Month Cohort_Index
0 536365 85123A WHITE HANGING HEART T-LIGHT HOLDER 6 2010-12-01 08:26:00 2.55 17850.0 United Kingdom 2010-12 2010-12 0
1 536365 71053 WHITE METAL LANTERN 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom 2010-12 2010-12 0
2 536365 84406B CREAM CUPID HEARTS COAT HANGER 8 2010-12-01 08:26:00 2.75 17850.0 United Kingdom 2010-12 2010-12 0
3 536365 84029G KNITTED UNION FLAG HOT WATER BOTTLE 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom 2010-12 2010-12 0
4 536365 84029E RED WOOLLY HOTTIE WHITE HEART. 6 2010-12-01 08:26:00 3.39 17850.0 United Kingdom 2010-12 2010-12 0
In [ ]:
# calculate the number of unique customers in each Cohort Month and Cohort Index 
cohort_data = data.groupby(['Cohort_Month','Cohort_Index'])['CustomerID'].nunique().reset_index()
cohort_data.columns = ['Cohort_Month','Cohort_Index', 'Customers_Count']
cohort_data
Out[ ]:
Cohort_Month Cohort_Index Customers_Count
0 2010-12 0 885
1 2010-12 1 324
2 2010-12 2 286
3 2010-12 3 340
4 2010-12 4 321
... ... ... ...
86 2011-10 1 86
87 2011-10 2 41
88 2011-11 0 324
89 2011-11 1 36
90 2011-12 0 41

91 rows × 3 columns

In [ ]:
# create the cohort table by pivoting the counts (rows: cohort month, columns: cohort index)
cohort_table = cohort_data.pivot(index='Cohort_Month', columns='Cohort_Index', values='Customers_Count')
cohort_table
Out[ ]:
Cohort_Index 0 1 2 3 4 5 6 7 8 9 10 11 12
Cohort_Month
2010-12 885.0 324.0 286.0 340.0 321.0 352.0 321.0 309.0 313.0 350.0 331.0 445.0 235.0
2011-01 417.0 92.0 111.0 96.0 134.0 120.0 103.0 101.0 125.0 136.0 152.0 49.0 NaN
2011-02 380.0 71.0 71.0 108.0 103.0 94.0 96.0 106.0 94.0 116.0 26.0 NaN NaN
2011-03 452.0 68.0 114.0 90.0 101.0 76.0 121.0 104.0 126.0 39.0 NaN NaN NaN
2011-04 300.0 64.0 61.0 63.0 59.0 68.0 65.0 78.0 22.0 NaN NaN NaN NaN
2011-05 284.0 54.0 49.0 49.0 59.0 66.0 75.0 27.0 NaN NaN NaN NaN NaN
2011-06 242.0 42.0 38.0 64.0 56.0 81.0 23.0 NaN NaN NaN NaN NaN NaN
2011-07 188.0 34.0 39.0 42.0 51.0 21.0 NaN NaN NaN NaN NaN NaN NaN
2011-08 169.0 35.0 42.0 41.0 21.0 NaN NaN NaN NaN NaN NaN NaN NaN
2011-09 299.0 70.0 90.0 34.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN
2011-10 358.0 86.0 41.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2011-11 324.0 36.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2011-12 41.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
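The cells above can be condensed into one function, which makes the logic easy to test on a tiny input. This is a sketch under the notebook's column names (`InvoiceDate`, `CustomerID`); `build_cohort_table` and the toy customers A, B and C are hypothetical. As in the notebook, the cohort index starts at 0 in the acquisition month.

```python
import pandas as pd


def build_cohort_table(df: pd.DataFrame) -> pd.DataFrame:
    """Unique customers per (cohort month, cohort index); a condensed sketch."""
    df = df.copy()
    df["Invoice_Month"] = df["InvoiceDate"].dt.to_period("M")
    # cohort month = first invoice month of each customer, broadcast to all of their rows
    df["Cohort_Month"] = df.groupby("CustomerID")["Invoice_Month"].transform("min")
    # cohort index = 12 * (year difference) + (month difference)
    df["Cohort_Index"] = (
        (df["Invoice_Month"].dt.year - df["Cohort_Month"].dt.year) * 12
        + (df["Invoice_Month"].dt.month - df["Cohort_Month"].dt.month)
    )
    counts = (
        df.groupby(["Cohort_Month", "Cohort_Index"])["CustomerID"]
        .nunique()
        .reset_index(name="Customers_Count")
    )
    return counts.pivot(index="Cohort_Month", columns="Cohort_Index",
                        values="Customers_Count")


# Hypothetical toy transactions: customers A and B join in Dec 2010, C in Jan 2011
toy = pd.DataFrame({
    "CustomerID": ["A", "A", "B", "B", "C", "C"],
    "InvoiceDate": pd.to_datetime([
        "2010-12-01", "2011-01-15",   # A: month 0 and month 1
        "2010-12-05", "2011-02-03",   # B: month 0 and month 2
        "2011-01-10", "2011-01-20",   # C: two purchases in its first month only
    ]),
})
print(build_cohort_table(toy))
```

For this toy input, the December 2010 cohort has 2 customers at index 0 and 1 at each of indexes 1 and 2, while the January 2011 cohort has 1 customer at index 0 and NaN elsewhere, the same triangular shape as the real table.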

Retention Analysis

Calculate Business Metrics: Retention Rate

The first column of the cohort table (cohort index 0) represents the cohort size, i.e. the total number of customers in the cohort. To calculate the retention rate, we divide each row of the cohort table by its cohort size.
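In symbols, with N_{c,k} the number of unique customers from cohort c who are active at cohort index k (an entry of the cohort table), and N_{c,0} the cohort size:

```latex
R_{c,k} = \frac{N_{c,k}}{N_{c,0}} \times 100\%

% Example: December 2010 cohort, cohort index 11 (values from the cohort table)
R_{\text{Dec 2010},\,11} = \frac{445}{885} \times 100\% \approx 50.3\%
```

By construction R_{c,0} = 100% for every cohort, which is why the first column of the retention table is constant.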

In [ ]:
# format the PeriodIndex labels as 'Month Year' (e.g. 'December 2010')
cohort_table.index = cohort_table.index.strftime('%B %Y')

# define cohort size and create the retention table
cohort_size = cohort_table.iloc[:,0]
retention_table = cohort_table.divide(cohort_size,axis=0)
retention_table.round(3)*100
Out[ ]:
Cohort_Index 0 1 2 3 4 5 6 7 8 9 10 11 12
Cohort_Month
December 2010 100.0 36.6 32.3 38.4 36.3 39.8 36.3 34.9 35.4 39.5 37.4 50.3 26.6
January 2011 100.0 22.1 26.6 23.0 32.1 28.8 24.7 24.2 30.0 32.6 36.5 11.8 NaN
February 2011 100.0 18.7 18.7 28.4 27.1 24.7 25.3 27.9 24.7 30.5 6.8 NaN NaN
March 2011 100.0 15.0 25.2 19.9 22.3 16.8 26.8 23.0 27.9 8.6 NaN NaN NaN
April 2011 100.0 21.3 20.3 21.0 19.7 22.7 21.7 26.0 7.3 NaN NaN NaN NaN
May 2011 100.0 19.0 17.3 17.3 20.8 23.2 26.4 9.5 NaN NaN NaN NaN NaN
June 2011 100.0 17.4 15.7 26.4 23.1 33.5 9.5 NaN NaN NaN NaN NaN NaN
July 2011 100.0 18.1 20.7 22.3 27.1 11.2 NaN NaN NaN NaN NaN NaN NaN
August 2011 100.0 20.7 24.9 24.3 12.4 NaN NaN NaN NaN NaN NaN NaN NaN
September 2011 100.0 23.4 30.1 11.4 NaN NaN NaN NaN NaN NaN NaN NaN NaN
October 2011 100.0 24.0 11.5 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
November 2011 100.0 11.1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
December 2011 100.0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
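The `axis=0` argument is what makes each row be divided by its own cohort size: it aligns `cohort_size` with the row index. (Plain `/` between a DataFrame and a Series aligns on the columns instead, which would not match here.) The sketch below checks this on a hypothetical two-cohort table against explicit NumPy broadcasting.

```python
import numpy as np
import pandas as pd

# Hypothetical 2-cohort table (rows = cohorts, columns = cohort index), not real data
cohort_table = pd.DataFrame(
    {0: [10.0, 4.0], 1: [5.0, 1.0], 2: [3.0, np.nan]},
    index=["December 2010", "January 2011"],
)
cohort_size = cohort_table.iloc[:, 0]  # first column = cohort size

# axis=0: match cohort_size to the row labels, so each row is divided by its own size
retention = cohort_table.divide(cohort_size, axis=0)

# The same computation with explicit NumPy broadcasting of a column vector
manual = cohort_table.to_numpy() / cohort_size.to_numpy()[:, None]
print(np.allclose(retention.to_numpy(), manual, equal_nan=True))  # True

# Percentages rounded to one decimal place
print((retention * 100).round(1))
```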

Visualizing The Retention Rate

In [ ]:
# Visualize the retention table
plt.figure(figsize=(23,10))
plt.title('Customer Retention Analysis', fontsize=20)
sns.heatmap(retention_table,annot=True,cmap='coolwarm',fmt='.0%')
plt.savefig('cohort_plt.png')

Interpreting The Retention Rate

The heatmap color-codes the retention rates, which makes it easy to see how they differ across cohorts and over time. Heatmaps are a common and effective way to present the results of a cohort retention analysis.

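To interpret the result: there are 13 cohorts (December 2010 to December 2011) and 13 cohort indexes (0 to 12). Retention rates are shown as percentages; on the color scale, low values are bluish, values around 50% are grayish and values close to 100% are reddish. For instance, the retention rate of the December 2010 cohort at cohort index 11 is 50% (50.3% in the retention table) and is shaded gray. This means that 50% of the customers acquired in December 2010 made at least one purchase 11 months later, in November 2011. In other words, half of the customers acquired in December 2010 were still active 11 months later. Note that the last value in each row (the December 2011 column) is noticeably lower for every cohort, most likely because the data set ends on 9 December 2011 and that month is only partially observed.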