Skip to content

Commit

Permalink
Merge pull request #1 from MaxGhenis/main
Browse files Browse the repository at this point in the history
Initial package
  • Loading branch information
MaxGhenis authored Oct 19, 2024
2 parents 02bf67d + bb60e5a commit 2353cd7
Show file tree
Hide file tree
Showing 22 changed files with 914 additions and 36 deletions.
36 changes: 0 additions & 36 deletions .github/workflows/blank.yml

This file was deleted.

35 changes: 35 additions & 0 deletions .github/workflows/ci-cd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: Python package

on: [push, pull_request]

jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest black isort flake8
# Install PyTorch separately as it has special requirements
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
# Install the package in development mode with all dependencies
python -m pip install -e ".[dev]"
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,27 @@
# statsmodels-sgd

Reimplementation of statsmodels using stochastic gradient descent

Use just like statsmodels, but with stochastic gradient descent.

## Installation

```bash
pip install git+https://github.com/PolicyEngine/statsmodels-sgd.git
```

## Example

```python
import statsmodels_sgd.api as sm_sgd

# Fit OLS model

model = sm_sgd.OLS(n_features=X.shape[1])
model.fit(X, y)

# Fit Logit model

model = sm_sgd.Logit(n_features=X.shape[1])
model.fit(X, y)
```
101 changes: 101 additions & 0 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
[project]
name = "statsmodels-sgd"
version = "0.1.0"
description = "A statsmodels-like package using SGD with gradient clipping for differential privacy"
authors = [{name = "Max Ghenis", email = "max@policyengine.org"}]
license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"numpy",
"pandas",
"statsmodels",
"torch",
]

[project.optional-dependencies]
dev = [
"pytest",
"pytest-cov",
"black",
"isort",
"flake8",
]

[build-system]
requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = ["statsmodels_sgd"]

[tool.poetry]
name = "statsmodels-sgd"
version = "0.1.0"
description = "A statsmodels-like package using SGD with gradient clipping for differential privacy"
authors = [{name = "Max Ghenis", email = "max@policyengine.org"}]
license = "MIT"
readme = "README.md"
packages = [{include = "statsmodels_sgd"}]

[tool.poetry.dependencies]
python = ">=3.8"
numpy = "*"
pandas = "*"
statsmodels = "*"
torch = "*"

[tool.poetry.dev-dependencies]
pytest = "*"
pytest-cov = "*"
black = "*"
isort = "*"
flake8 = "*"
2 changes: 2 additions & 0 deletions statsmodels_sgd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import api
from . import regression
5 changes: 5 additions & 0 deletions statsmodels_sgd/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .regression.linear_model import OLS
from .regression.discrete_model import Logit

# Make commonly used objects available at api level
__all__ = ["OLS", "Logit"]
61 changes: 61 additions & 0 deletions statsmodels_sgd/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torch.nn as nn
import torch.optim as optim
import numpy as np
from .tools import (
add_constant,
calculate_standard_errors,
calculate_t_p_values,
)

import sys

print(f"Python executable: {sys.executable}")
print(f"Python version: {sys.version}")
print(f"Python path: {sys.path}")

try:
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch installation path: {torch.__file__}")
except ImportError as e:
print(f"Error importing torch: {e}")
print("Installed packages:")
import subprocess

result = subprocess.run(
[sys.executable, "-m", "pip", "list"], capture_output=True, text=True
)
print(result.stdout)


class BaseModel(nn.Module):
def __init__(
self,
n_features,
learning_rate=0.01,
epochs=1000,
batch_size=32,
clip_value=1.0,
):
super().__init__()
self.linear = nn.Linear(n_features, 1)
self.learning_rate = learning_rate
self.epochs = epochs
self.batch_size = batch_size
self.clip_value = clip_value
self.results_ = None

def forward(self, x):
return self.linear(x)

def fit(self, X, y, sample_weight=None):
raise NotImplementedError("Subclasses must implement this method")

def predict(self, X):
raise NotImplementedError("Subclasses must implement this method")

def summary(self):
if self.results_ is None:
raise ValueError("Model has not been fit yet.")
return self.results_
5 changes: 5 additions & 0 deletions statsmodels_sgd/docs/_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
title: Statsmodels-SGD Documentation
author: Your Name
logo: logo.png
execute:
execute_notebooks: force
7 changes: 7 additions & 0 deletions statsmodels_sgd/docs/_toc.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
format: jb-book
root: index
chapters:
- file: installation
- file: ols-example
- file: logit-example
- file
Loading

0 comments on commit 2353cd7

Please sign in to comment.