Skip to content

Commit

Permalink
Add CLV Cumulative Transaction Utility (pymc-labs#1076)
Browse files Browse the repository at this point in the history
* copy lifetimes func into utils.py

* copy lifetimes tests

* test dataset

* pred loop notebook testing

* linting checks and df_cum_transactions fixture

* linting checks and df_cum_transactions fixture

* TODOs

* array appends

* bg default param and xarray mean

* docstrings

* notebook cleanup
  • Loading branch information
ColtAllen authored Oct 5, 2024
1 parent 2187dd3 commit 2151aa7
Show file tree
Hide file tree
Showing 4 changed files with 538 additions and 77 deletions.
314 changes: 242 additions & 72 deletions docs/source/notebooks/clv/dev/utilities.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"id": "435ed203-5c3c-4efc-93d1-abac66ce7187",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from pymc_marketing.clv import utils\n",
"from pymc_marketing.clv import ParetoNBDModel\n",
"from pymc_marketing.prior import Prior\n",
"\n",
"import pytensor\n",
"\n",
"import pandas as pd"
"#set flag to fix open issue\n",
"pytensor.config.cxx = '/usr/bin/clang++'"
]
},
{
Expand Down Expand Up @@ -471,7 +478,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/coltallen/Projects/pymc-marketing/pymc_marketing/clv/utils.py:698: UserWarning: RFM score will not exceed 2 for f_quartile. Specify a custom segment_config\n",
"/Users/coltallen/Projects/pymc-marketing/pymc_marketing/clv/utils.py:707: UserWarning: RFM score will not exceed 2 for f_quartile. Specify a custom segment_config\n",
" warnings.warn(\n"
]
}
Expand All @@ -488,12 +495,70 @@
")"
]
},
{
"cell_type": "markdown",
"id": "509f8d13-de5b-4a24-a468-a757888088f1",
"metadata": {},
"source": [
"`_expected_cumulative_transactions` is a utility function for creating cumulative plots over time:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "932ac4e5-361e-42fa-97d3-d8e508128944",
"execution_count": 8,
"id": "b320a25b-b449-4c28-ac36-4a9ca573403a",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2b3cf0c8e98407c90efa488a8db59da",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Output()"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
"</pre>\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
Expand All @@ -515,107 +580,212 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>customer_id</th>\n",
" <th>frequency</th>\n",
" <th>recency</th>\n",
" <th>monetary_value</th>\n",
" <th>rfm_score</th>\n",
" <th>segment</th>\n",
" <th>actual</th>\n",
" <th>predicted</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>1.5</td>\n",
" <td>321</td>\n",
" <td>Other</td>\n",
" <td>0</td>\n",
" <td>4.215266</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>5.0</td>\n",
" <td>2.0</td>\n",
" <td>111</td>\n",
" <td>Inactive Customer</td>\n",
" <td>19</td>\n",
" <td>16.569583</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>4.0</td>\n",
" <td>4.5</td>\n",
" <td>122</td>\n",
" <td>At Risk Customer</td>\n",
" <td>42</td>\n",
" <td>37.214571</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>7.0</td>\n",
" <td>324</td>\n",
" <td>Top Spender</td>\n",
" <td>81</td>\n",
" <td>66.721456</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>1.0</td>\n",
" <td>3.0</td>\n",
" <td>12.0</td>\n",
" <td>214</td>\n",
" <td>At Risk Customer</td>\n",
" <td>119</td>\n",
" <td>105.392417</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>6</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>313</td>\n",
" <td>Top Spender</td>\n",
" <td>192</td>\n",
" <td>153.733780</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>261</td>\n",
" <td>210.405989</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>351</td>\n",
" <td>275.417324</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>428</td>\n",
" <td>349.069329</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>504</td>\n",
" <td>431.079058</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>610</td>\n",
" <td>520.409690</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>733</td>\n",
" <td>616.028740</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>828</td>\n",
" <td>712.386501</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>914</td>\n",
" <td>805.471057</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>1005</td>\n",
" <td>895.569697</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>1078</td>\n",
" <td>982.929921</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>1149</td>\n",
" <td>1067.766834</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>1222</td>\n",
" <td>1150.268848</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>1286</td>\n",
" <td>1230.602131</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>1359</td>\n",
" <td>1308.914149</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>1414</td>\n",
" <td>1385.336498</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>1484</td>\n",
" <td>1459.987206</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>1517</td>\n",
" <td>1532.972626</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>1573</td>\n",
" <td>1604.388989</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>1672</td>\n",
" <td>1674.323708</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" customer_id frequency recency monetary_value rfm_score \\\n",
"0 1 2.0 0.0 1.5 321 \n",
"1 2 1.0 5.0 2.0 111 \n",
"2 3 2.0 4.0 4.5 122 \n",
"3 4 2.0 0.0 7.0 324 \n",
"4 5 1.0 3.0 12.0 214 \n",
"5 6 1.0 0.0 5.0 313 \n",
"\n",
" segment \n",
"0 Other \n",
"1 Inactive Customer \n",
"2 At Risk Customer \n",
"3 Top Spender \n",
"4 At Risk Customer \n",
"5 Top Spender "
" actual predicted\n",
"0 0 4.215266\n",
"1 19 16.569583\n",
"2 42 37.214571\n",
"3 81 66.721456\n",
"4 119 105.392417\n",
"5 192 153.733780\n",
"6 261 210.405989\n",
"7 351 275.417324\n",
"8 428 349.069329\n",
"9 504 431.079058\n",
"10 610 520.409690\n",
"11 733 616.028740\n",
"12 828 712.386501\n",
"13 914 805.471057\n",
"14 1005 895.569697\n",
"15 1078 982.929921\n",
"16 1149 1067.766834\n",
"17 1222 1150.268848\n",
"18 1286 1230.602131\n",
"19 1359 1308.914149\n",
"20 1414 1385.336498\n",
"21 1484 1459.987206\n",
"22 1517 1532.972626\n",
"23 1573 1604.388989\n",
"24 1672 1674.323708"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"segments"
"url_cdnow = \"https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/cdnow_transactions.csv\"\n",
"raw_trans = pd.read_csv(url_cdnow)\n",
"\n",
"rfm_data = utils.rfm_summary(\n",
" raw_trans, \n",
" customer_id_col = \"id\", \n",
" datetime_col = \"date\", \n",
" datetime_format = \"%Y%m%d\",\n",
" time_unit = \"D\",\n",
" observation_period_end = \"19970930\",\n",
" time_scaler = 7,\n",
")\n",
"\n",
"model_config = {\n",
" \"r_prior\": Prior(\"HalfFlat\"),\n",
" \"alpha_prior\": Prior(\"HalfFlat\"),\n",
" \"s_prior\": Prior(\"HalfFlat\"),\n",
" \"beta_prior\": Prior(\"HalfFlat\"),\n",
"}\n",
"\n",
"pnbd = ParetoNBDModel(data=rfm_data,model_config=model_config)\n",
"\n",
"pnbd.fit()\n",
"\n",
"df_cum = utils._expected_cumulative_transactions(\n",
" model=pnbd,\n",
" transactions=raw_trans,\n",
" customer_id_col=\"id\",\n",
" datetime_col=\"date\",\n",
" t=25*7,\n",
" datetime_format=\"%Y%m%d\",\n",
" time_unit=\"D\",\n",
" time_scaler= 7,\n",
")\n",
"\n",
"df_cum"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "143761be-434a-459a-97ef-7c3f62e48704",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def expected_purchases_new_customer(
self,
data: pd.DataFrame | None = None,
*,
t: np.ndarray | pd.Series,
t: int | np.ndarray | pd.Series | None = None,
) -> xarray.DataArray:
r"""Compute the expected number of purchases for a new customer across *t* time periods.
Expand Down
Loading

0 comments on commit 2151aa7

Please sign in to comment.