<p align="center"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/dwreeves/dbt_linreg/main/docs/src/img/dbt-linreg-banner-dark.png#readme-logo"> <img src="https://raw.githubusercontent.com/dwreeves/dbt_linreg/main/docs/src/img/dbt-linreg-banner-light.png#readme-logo" alt="dbt_linreg logo"> </picture> </p> <p align="center"> <em>Linear regression in any SQL dialect, powered by dbt.</em> </p> <p align="center"> <img src="https://github.com/dwreeves/dbt_linreg/workflows/tests/badge.svg" alt="Tests badge"> <img src="https://github.com/dwreeves/dbt_linreg/workflows/docs/badge.svg" alt="Docs badge"> </p>

Overview

dbt_linreg is an easy way to perform linear regression and ridge regression in SQL (Snowflake, DuckDB, and more) with OLS using dbt's Jinja2 templating.

Reasons to use dbt_linreg:

Installation

dbt-core >=1.2.0 is required to install dbt_linreg.

Add this to the packages: list in your dbt project's packages.yml:

  - package: "dwreeves/dbt_linreg"
    version: "0.2.6"

The full file will look something like this:

packages:
  # ...
  # Other packages here
  # ...
  - package: "dwreeves/dbt_linreg"
    version: "0.2.6"

Examples

Simple example

The following example runs a linear regression of y on 3 columns (xa, xb, and xc), using data in the dbt model named simple_matrix. It outputs the results in "long" format, and rounds the coefficients to 5 decimal places:

{{
  config(
    materialized="table"
  )
}}
select * from {{
  dbt_linreg.ols(
    table=ref('simple_matrix'),
    endog='y',
    exog=['xa', 'xb', 'xc'],
    format='long',
    format_options={'round': 5}
  )
}} as linreg

Output:

variable_name  coefficient  standard_error  t_statistic
const          10.0         0.00462         2163.27883
xa             5.0          0.46226         10.81639
xb             7.0          0.46226         15.14295
xc             9.0          0.46226         19.46951

Note: simple_matrix is one of the test cases, so you can try this yourself! Standard errors are constant across xa, xb, xc, because simple_matrix is orthonormal.
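
As a quick sanity check on the output columns, the t-statistic is the coefficient divided by its standard error. The Python sketch below recomputes it from the values shown above; the helper name t_statistic is illustrative and not part of dbt_linreg, and small differences are expected because the displayed values are rounded.

```python
# Values copied from the output above: (coefficient, standard_error, t_statistic).
output = {
    "const": (10.0, 0.00462, 2163.27883),
    "xa": (5.0, 0.46226, 10.81639),
    "xb": (7.0, 0.46226, 15.14295),
    "xc": (9.0, 0.46226, 19.46951),
}


def t_statistic(coefficient, standard_error):
    """Illustrative helper: t = coefficient / standard error."""
    return coefficient / standard_error


for name, (coef, se, reported_t) in output.items():
    recomputed = t_statistic(coef, se)
    # Gaps come from rounding: the standard errors above are shown to 5 places.
    print(name, round(recomputed, 3), reported_t)
```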

Complex example

The following hypothetical example shows multiple ridge regressions (one per product_id) on a table that is preprocessed substantially. After the fact, predictions are run, and the R-squared of each regression is calculated at the end.

This example shows that, although dbt_linreg does not implement everything for you, the OLS implementation does most of the hard work. This gives you the freedom to do things you've never been able to do before in SQL!

{{
  config(
    materialized="table"
  )
}}
with

preprocessed_data as (

  select
    product_id,
    time,
    price,
    -- Natural log, so that exp() in the predict step inverts it.
    ln(price) as log_price,
    epoch(time) as t,
    sin(epoch(time)*pi()*2 / (60*60*24*365)) as sin_t,
    cos(epoch(time)*pi()*2 / (60*60*24*365)) as cos_t
  from
    {{ ref('prices') }}

),

preprocessed_and_normalized_data as (

  select
    product_id,
    time,
    price,
    log_price,
    -- Standardize each regressor to zero mean and unit variance.
    (t - avg(t) over ()) / (stddev(t) over ()) as t_norm,
    (sin_t - avg(sin_t) over ()) / (stddev(sin_t) over ()) as sin_t_norm,
    (cos_t - avg(cos_t) over ()) / (stddev(cos_t) over ()) as cos_t_norm
  from
    preprocessed_data

),

coefficients as (

    select * from {{
      dbt_linreg.ols(
        table='preprocessed_and_normalized_data',
        endog='log_price',
        exog=['t_norm', 'sin_t_norm', 'cos_t_norm'],
        group_by=['product_id'],
        alpha=0.0001
      )
    }}

),

predict as (

  select
    d.product_id,
    d.time,
    d.price,
    exp(
      c.const
      + d.t_norm * c.t_norm
      + d.sin_t_norm * c.sin_t_norm
      + d.cos_t_norm * c.cos_t_norm) as predicted_price
  from
    preprocessed_and_normalized_data as d
  join
    coefficients as c
  on
    d.product_id = c.product_id

)

select
  product_id,
  pow(corr(predicted_price, price), 2) as r_squared
from
  predict
group by
  product_id
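
The last step treats the squared Pearson correlation between predicted and actual prices as the R-squared. Here is a minimal Python sketch of that calculation for a single product; the numbers are made up and the corr helper is illustrative, mirroring what the SQL corr() aggregate computes.

```python
import math


def corr(a, b):
    """Pearson correlation between two equal-length lists."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Made-up actual and predicted prices for one product_id.
price = [10.0, 12.0, 15.0, 11.0]
predicted_price = [10.5, 11.5, 14.0, 12.0]

# Same idea as pow(corr(predicted_price, price), 2) in the final select.
r_squared = corr(predicted_price, price) ** 2
print(r_squared)
```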

Supported Databases

dbt_linreg should work with most SQL databases, but so far, testing has been done for the following database tools:

If dbt_linreg does not work in your database tool, please let me know in a bug report and I can make sure it is supported.

* Minimal support. Postgres is syntactically supported, but is not performant under certain circumstances.

API

The only function available in the public API is the dbt_linreg.ols() macro.

Using Python typing notation, the full API for dbt_linreg.ols() looks like this:

def ols(
    table: str,
    endog: str,
    exog: Union[str, list[str]],
    add_constant: bool = True,
    format: Literal['wide', 'long'] = 'wide',
    format_options: Optional[dict[str, Any]] = None,
    group_by: Optional[Union[str, list[str]]] = None,
    alpha: Optional[Union[float, list[float]]] = None,
    method: Literal['chol', 'fwl'] = 'chol',
    method_options: Optional[dict[str, Any]] = None
):
    ...

Where:

Formats and format options

Outputs can be returned either in format='long' or format='wide'.

(In the future, I might add one or two more formats, notably a summary table format.)

All formats have their own format options, which can be passed into the format_options= arg as a dict, e.g. format_options={'foo': 'bar'}.

Options for format='long'

These options are available for format='long' only when method='chol':

Options for format='wide'

Methods and method options

There are currently two valid methods for calculating regression coefficients:

chol method

👍 This is the suggested method (and the default) for calculating regressions!

This method calculates regression coefficients using the Moore-Penrose pseudo-inverse, and the inverse of X'X is calculated using Cholesky decomposition, hence it is referred to as chol.
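
To make the idea concrete, here is a minimal pure-Python sketch of solving the normal equations X'X b = X'y with a Cholesky factorization. It is not the SQL that dbt_linreg generates: the function names (cholesky, solve_cholesky, ols) and the toy data are assumptions for illustration, and the sketch solves the triangular systems directly rather than forming the inverse of X'X explicitly.

```python
import math


def cholesky(a):
    """Lower-triangular L with L L' = a, for a symmetric positive-definite a."""
    n = len(a)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(low[i][k] * low[j][k] for k in range(j))
            if i == j:
                low[i][j] = math.sqrt(a[i][i] - s)
            else:
                low[i][j] = (a[i][j] - s) / low[j][j]
    return low


def solve_cholesky(low, b):
    """Solve (L L') beta = b by forward, then backward, substitution."""
    n = len(b)
    z = [0.0] * n
    for i in range(n):
        z[i] = (b[i] - sum(low[i][k] * z[k] for k in range(i))) / low[i][i]
    beta = [0.0] * n
    for i in reversed(range(n)):
        tail = sum(low[k][i] * beta[k] for k in range(i + 1, n))
        beta[i] = (z[i] - tail) / low[i][i]
    return beta


def ols(rows, y):
    """OLS coefficients from the normal equations X'X beta = X'y."""
    p = len(rows[0])
    xtx = [[sum(r[i] * r[j] for r in rows) for j in range(p)] for i in range(p)]
    xty = [sum(r[i] * t for r, t in zip(rows, y)) for i in range(p)]
    return solve_cholesky(cholesky(xtx), xty)


# Each row is [1 (constant), xa, xb]; y = 10 + 5*xa + 7*xb exactly.
rows = [[1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1], [1, 2, 1]]
y = [10 + 5 * r[1] + 7 * r[2] for r in rows]
beta = ols(rows, y)
print(beta)  # approximately [10, 5, 7], up to floating-point error
```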

Options for method='chol'

Specify these in a dict using the method_options= kwarg:

fwl method

This method is generally not recommended.

Simple univariate regression coefficients are simply covar_pop(y, x) / var_pop(x).
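
The same formula in a few lines of Python, as a sketch; simple_slope is an illustrative name, and the population covariance and variance mirror the SQL covar_pop and var_pop aggregates.

```python
def simple_slope(x, y):
    """Slope of y on x: population covariance over population variance."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    covar_pop = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y)) / n
    var_pop = sum((a - mean_x) ** 2 for a in x) / n
    return covar_pop / var_pop


x = [1.0, 2.0, 3.0, 4.0]
y = [5.0, 7.0, 9.0, 11.0]  # exactly y = 3 + 2x

slope = simple_slope(x, y)
# The intercept follows from the means: mean(y) - slope * mean(x).
intercept = sum(y) / len(y) - slope * sum(x) / len(x)
print(slope, intercept)
```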

The multiple regression implementation uses a technique described in section 3.2.3 Multiple Regression from Simple Univariate Regression of TEoSL (source). Econometricians know this as the Frisch-Waugh-Lovell theorem, hence the method is referred to as fwl internally in the code base.
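
A small Python sketch of the Frisch-Waugh-Lovell idea with two regressors: partial one regressor out of the other using a simple regression, then regress y on the residual. This illustrates the theorem only and is not the package's SQL; the helper names and data are made up.

```python
def mean(v):
    return sum(v) / len(v)


def slope(x, y):
    """Simple univariate regression slope of y on x."""
    mx, my = mean(x), mean(y)
    num = sum((a - mx) * (b - my) for a, b in zip(x, y))
    return num / sum((a - mx) ** 2 for a in x)


def residuals(x, y):
    """Residuals from the simple regression of y on x (with intercept)."""
    b = slope(x, y)
    a = mean(y) - b * mean(x)
    return [yi - (a + b * xi) for xi, yi in zip(x, y)]


x1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
x2 = [2.0, 1.0, 4.0, 3.0, 6.0, 5.0]
y = [3.0 + 2.0 * a - 1.5 * b for a, b in zip(x1, x2)]

# Partial x2 out of x1, then regress y on what is left of x1 (and vice versa).
beta_x1 = slope(residuals(x2, x1), y)
beta_x2 = slope(residuals(x1, x2), y)
print(beta_x1, beta_x2)  # recovers the multiple-regression coefficients
```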

Ridge regression is implemented using the augmentation technique described in Exercise 12 of Chapter 3 of TEoSL (source).
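
The augmentation technique can be sketched in Python for the simplest case: one regressor and no intercept (both simplifying assumptions for illustration). Appending one extra row with x = sqrt(alpha) and y = 0, then running plain OLS, reproduces the ridge estimate.

```python
import math


def ols_through_origin(x, y):
    """OLS slope for y = b * x with no intercept."""
    return sum(a * b for a, b in zip(x, y)) / sum(a * a for a in x)


x = [1.0, 2.0, 3.0, 4.0]
y = [2.1, 3.9, 6.2, 7.8]
alpha = 0.5

# Ridge estimate written out directly: sum(xy) / (sum(x^2) + alpha).
ridge_closed_form = sum(a * b for a, b in zip(x, y)) / (sum(a * a for a in x) + alpha)

# Same estimate via augmentation: one extra "observation" per regressor.
x_aug = x + [math.sqrt(alpha)]
y_aug = y + [0.0]
ridge_augmented = ols_through_origin(x_aug, y_aug)

print(ridge_closed_form, ridge_augmented)
```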

There are a few reasons why this method is discouraged over the chol method:

So when should you use fwl? The main use case is in OLTP systems (e.g. Postgres) for unregularized coefficient estimation. Long story short, the chol method relies on subquery optimization to be more performant than fwl; however, OLTP systems do not benefit at all from subquery optimization. This means that fwl is slightly more performant in this context.

Notes

Possible future features

Some things I am thinking about working on down the line:

FAQ

How does this work?

See the Methods and method options section for a full breakdown of each linear regression implementation.

All approaches were validated using Statsmodels sm.OLS(). Note that the ridge regression coefficients differ very slightly from Statsmodels's outputs for currently unknown reasons, but the coefficients are very close (I enforce a <0.01% deviation from Statsmodels's ridge regression coefficients in my integration tests).

BigQuery (or other database) has linear regression implemented natively. Why should I use dbt_linreg over that?

You don't have to use this. Most warehouses don't support multiple regression out of the box, so this satisfies a niche for those database tools which don't.

That said, even in BigQuery, it may be very useful to extract coefficients within a query instead of generating a separate MODEL object through a DDL statement, for a few reasons. Even in more black box predictive contexts, being able to predict in the same SELECT statement as training can be convenient. Additionally, BigQuery does not expose model coefficients to users, and this can be a dealbreaker in many contexts where you care about your coefficients as measurements, not as predictive model parameters. Lastly, group_by is akin to estimating parameters for multiple linear regressions at once.

Overall, I would say this is pretty different from what BigQuery's CREATE MODEL is doing; use whatever makes sense for your use case! But keep in mind that for large numbers of variables, a native implementation of linear regression will be noticeably more efficient than this implementation.

Why is L2 regularization / ridge regression supported, but not L1 regularization / LASSO?

There is no closed-form solution to L1 regularization, which makes it very very hard to add through raw SQL. L2 regularization has a closed-form solution and can be implemented using a pre-processing trick.
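
For reference, the pre-processing trick can be written out as follows. This assumes a penalty of the form alpha times the squared norm of the coefficients, applied to all p columns of X; how the package scales the penalty or treats the constant term is not covered here.

```latex
\tilde{X} = \begin{bmatrix} X \\ \sqrt{\alpha}\, I_p \end{bmatrix},
\qquad
\tilde{y} = \begin{bmatrix} y \\ 0_p \end{bmatrix}
\quad\Longrightarrow\quad
\hat{\beta}_{\text{ridge}}
  = (\tilde{X}^\top \tilde{X})^{-1} \tilde{X}^\top \tilde{y}
  = (X^\top X + \alpha I_p)^{-1} X^\top y
```

The second equality holds because the augmented cross-products are X'X + alpha I and X'y. The L1 penalty has no analogous rewrite as ordinary least squares on augmented data.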

Is the group_by=[...] argument like categorical variables / one-hot encodings?

No. You should think of the group by more as a seemingly unrelated regressions implementation than as a categorical variable implementation. It's running multiple regressions and each individual partition is its own y vector and X matrix. This is not a replacement for dummy variables.
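
In Python terms, the behavior described above looks roughly like the sketch below: split the rows by group, then fit a completely separate regression per group. The data and the fit_line helper are made up for illustration.

```python
from collections import defaultdict


def fit_line(points):
    """Intercept and slope of y on x for a list of (x, y) pairs."""
    n = len(points)
    mx = sum(x for x, _ in points) / n
    my = sum(y for _, y in points) / n
    num = sum((x - mx) * (y - my) for x, y in points)
    slope = num / sum((x - mx) ** 2 for x, _ in points)
    return my - slope * mx, slope


rows = [
    ("a", 1.0, 3.0), ("a", 2.0, 5.0), ("a", 3.0, 7.0),  # y = 1 + 2x
    ("b", 1.0, 9.0), ("b", 2.0, 8.0), ("b", 3.0, 7.0),  # y = 10 - x
]

# Each partition gets its own y vector and X matrix.
partitions = defaultdict(list)
for product_id, x, y in rows:
    partitions[product_id].append((x, y))

coefficients = {pid: fit_line(pts) for pid, pts in partitions.items()}
print(coefficients)
```

Note that the two groups end up with different slopes, which a single regression with a dummy variable for the group would not allow.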

Why aren't categorical variables / one-hot encodings supported?

I opt to leave out dummy variable support because it's tricky, and I want to keep the API clean and mull over how best to implement that at the highest level.

Note that you couldn't simply add categorical variables in the same list as numeric variables because Jinja2 templating is not natively aware of the types you're feeding through it, nor does Jinja2 know the values that a string variable can take on. The way you would actually implement categorical variables is with group by trickery (i.e. center both y and X by categorical variable group means), although I am not sure how to do that efficiently for more than one categorical variable column.
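
The group-mean centering idea can be sketched in Python for a single categorical column: subtract each group's mean from both x and y, then run one pooled regression on the centered values. The data is made up so that both groups share a slope of 2 but have different intercepts; this is an illustration of the technique, not something dbt_linreg does for you.

```python
from collections import defaultdict

rows = [
    ("a", 1.0, 3.0), ("a", 2.0, 5.0), ("a", 3.0, 7.0),     # y = 1 + 2x
    ("b", 4.0, 18.0), ("b", 5.0, 20.0), ("b", 6.0, 22.0),  # y = 10 + 2x
]

# Accumulate per-group sums of x and y, plus a row count.
sums = defaultdict(lambda: [0.0, 0.0, 0])
for g, x, y in rows:
    sums[g][0] += x
    sums[g][1] += y
    sums[g][2] += 1
means = {g: (sx / n, sy / n) for g, (sx, sy, n) in sums.items()}

# Center x and y by their group means, then regress centered y on centered x.
centered = [(x - means[g][0], y - means[g][1]) for g, x, y in rows]
within_slope = sum(x * y for x, y in centered) / sum(x * x for x, _ in centered)
print(within_slope)  # the common slope, with group intercepts absorbed
```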

If you'd like to regress on a categorical variable, for now you'll need to do your own feature engineering, e.g. (foo = 'bar')::int as foo_bar

Why are there no p-values?

This is planned for the future, so stay tuned! P-values would require a lookup on a dimension table, which is a significant amount of work to manage nicely, but I hope to get to it soon.

In the meantime, you can implement this yourself: create a dimension table, then left join your t-statistics to it on a half-open interval to look up a p-value.
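
Here is a rough Python sketch of that lookup. The table is hypothetical: each threshold starts a half-open interval of |t| that runs up to the next threshold and maps to an upper bound on the two-sided p-value. The thresholds are the usual large-sample normal critical values, so for small samples you would need a t-distribution table keyed on degrees of freedom as well.

```python
import bisect

# Hypothetical dimension table, sorted by the lower edge of each interval.
# Row i covers T_LOWER[i] <= |t| < T_LOWER[i + 1]; the last row is open-ended.
T_LOWER = [0.0, 1.645, 1.960, 2.576]
P_UPPER = [1.00, 0.10, 0.05, 0.01]


def p_value_bound(t_statistic):
    """Upper bound on the two-sided p-value (normal approximation)."""
    idx = bisect.bisect_right(T_LOWER, abs(t_statistic)) - 1
    return P_UPPER[idx]


print(p_value_bound(10.81639))  # falls in the last interval
print(p_value_bound(-2.0))      # the sign is ignored for a two-sided test
```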

Trademark & Copyright

dbt is a trademark of dbt Labs.

This package is unaffiliated with dbt Labs.