-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from MaxGhenis/main
Initial package
- Loading branch information
Showing
22 changed files
with
914 additions
and
36 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
name: Python package | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.12"] | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install pytest black isort flake8 | ||
# Install PyTorch separately as it has special requirements | ||
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu | ||
# Install the package in development mode with all dependencies | ||
python -m pip install -e ".[dev]" | ||
- name: Lint with flake8 | ||
run: | | ||
# stop the build if there are Python syntax errors or undefined names | ||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics | ||
# exit-zero treats all errors as warnings | ||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics | ||
- name: Test with pytest | ||
run: | | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,27 @@ | ||
# statsmodels-sgd | ||
|
||
Reimplementation of statsmodels using stochastic gradient descent | ||
|
||
Use just like statsmodels, but with stochastic gradient descent. | ||
|
||
## Installation | ||
|
||
```bash | ||
pip install git+https://github.com/PolicyEngine/statsmodels-sgd.git | ||
``` | ||
|
||
## Example | ||
|
||
```python | ||
import statsmodels_sgd.api as sm_sgd | ||
|
||
# Fit OLS model | ||
|
||
model = sm_sgd.OLS(n_features=X.shape[1]) | ||
model.fit(X, y) | ||
|
||
# Fit Logit model | ||
|
||
model = sm_sgd.Logit(n_features=X.shape[1]) | ||
model.fit(X, y) | ||
``` |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
[project] | ||
name = "statsmodels-sgd" | ||
version = "0.1.0" | ||
description = "A statsmodels-like package using SGD with gradient clipping for differential privacy" | ||
authors = [{name = "Max Ghenis", email = "max@policyengine.org"}] | ||
license = {file = "LICENSE"} | ||
readme = "README.md" | ||
requires-python = ">=3.8" | ||
classifiers = [ | ||
"Programming Language :: Python :: 3", | ||
"License :: OSI Approved :: MIT License", | ||
"Operating System :: OS Independent", | ||
] | ||
dependencies = [ | ||
"numpy", | ||
"pandas", | ||
"statsmodels", | ||
"torch", | ||
] | ||
|
||
[project.optional-dependencies] | ||
dev = [ | ||
"pytest", | ||
"pytest-cov", | ||
"black", | ||
"isort", | ||
"flake8", | ||
] | ||
|
||
[build-system] | ||
requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.setuptools] | ||
packages = ["statsmodels_sgd"] | ||
|
||
[tool.poetry] | ||
name = "statsmodels-sgd" | ||
version = "0.1.0" | ||
description = "A statsmodels-like package using SGD with gradient clipping for differential privacy" | ||
authors = [{name = "Max Ghenis", email = "max@policyengine.org"}] | ||
license = "MIT" | ||
readme = "README.md" | ||
packages = [{include = "statsmodels_sgd"}] | ||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.8" | ||
numpy = "*" | ||
pandas = "*" | ||
statsmodels = "*" | ||
torch = "*" | ||
|
||
[tool.poetry.dev-dependencies] | ||
pytest = "*" | ||
pytest-cov = "*" | ||
black = "*" | ||
isort = "*" | ||
flake8 = "*" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from . import api | ||
from . import regression |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .regression.linear_model import OLS | ||
from .regression.discrete_model import Logit | ||
|
||
# Make commonly used objects available at api level | ||
__all__ = ["OLS", "Logit"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import numpy as np | ||
from .tools import ( | ||
add_constant, | ||
calculate_standard_errors, | ||
calculate_t_p_values, | ||
) | ||
|
||
import sys | ||
|
||
print(f"Python executable: {sys.executable}") | ||
print(f"Python version: {sys.version}") | ||
print(f"Python path: {sys.path}") | ||
|
||
try: | ||
import torch | ||
|
||
print(f"PyTorch version: {torch.__version__}") | ||
print(f"PyTorch installation path: {torch.__file__}") | ||
except ImportError as e: | ||
print(f"Error importing torch: {e}") | ||
print("Installed packages:") | ||
import subprocess | ||
|
||
result = subprocess.run( | ||
[sys.executable, "-m", "pip", "list"], capture_output=True, text=True | ||
) | ||
print(result.stdout) | ||
|
||
|
||
class BaseModel(nn.Module): | ||
def __init__( | ||
self, | ||
n_features, | ||
learning_rate=0.01, | ||
epochs=1000, | ||
batch_size=32, | ||
clip_value=1.0, | ||
): | ||
super().__init__() | ||
self.linear = nn.Linear(n_features, 1) | ||
self.learning_rate = learning_rate | ||
self.epochs = epochs | ||
self.batch_size = batch_size | ||
self.clip_value = clip_value | ||
self.results_ = None | ||
|
||
def forward(self, x): | ||
return self.linear(x) | ||
|
||
def fit(self, X, y, sample_weight=None): | ||
raise NotImplementedError("Subclasses must implement this method") | ||
|
||
def predict(self, X): | ||
raise NotImplementedError("Subclasses must implement this method") | ||
|
||
def summary(self): | ||
if self.results_ is None: | ||
raise ValueError("Model has not been fit yet.") | ||
return self.results_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
title: Statsmodels-SGD Documentation | ||
author: Your Name | ||
logo: logo.png | ||
execute: | ||
execute_notebooks: force |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
format: jb-book | ||
root: index | ||
chapters: | ||
- file: installation | ||
- file: ols-example | ||
- file: logit-example | ||
- file |
Oops, something went wrong.