Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic Attention attribution #148

Merged
merged 35 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b9ccbf2
added jupyterlab dependency (for easier testing)
lsickert Sep 23, 2022
08d5dbd
initial commit attention methods, added output_attentions parameter t…
lsickert Oct 17, 2022
57ae54a
Merge branch 'main' into attention-attribution
lsickert Oct 19, 2022
544321c
added basic attention method stubs\n added attention method registry
lsickert Oct 24, 2022
d0e859f
reverted changes to output generation (forward pass done inside attri…
lsickert Nov 21, 2022
a2a2021
first working version of basic attention methods
lsickert Nov 21, 2022
7b3c4fd
fixed rounding of values in cli output
lsickert Nov 22, 2022
13fc9f3
added documentation to most methods and generalized functions
lsickert Nov 22, 2022
bb13036
Merge branch 'main' into attention-attribution\n\nNeeded to downgrade…
lsickert Nov 23, 2022
9340343
removed python 3.11 build target
lsickert Nov 24, 2022
3cfd706
fix safety warnings
lsickert Nov 24, 2022
2765d63
set correct python version in pyproject.toml
lsickert Nov 25, 2022
4dd442f
regenerated requirements without 3.11
lsickert Nov 25, 2022
b14407c
Merge branch 'main' into attention-attribution, quick fix for mps issue
lsickert Dec 9, 2022
6a72166
Merge branch 'main' into attention-attribution
lsickert Dec 12, 2022
6535b09
merge branch 'main' into attention-attribution
lsickert Jan 2, 2023
624435e
update deps after merge
lsickert Jan 2, 2023
06f89a8
include 3.11 as build target
lsickert Jan 2, 2023
7bcbe92
fix different attribution_step argument formatting
lsickert Jan 2, 2023
b2fc73c
added basic decoder-only support
lsickert Jan 2, 2023
b044b4c
fixed output error for decoder only models
lsickert Jan 3, 2023
8c344b7
removed unnecessary convergence delta references in attention attribu…
lsickert Jan 3, 2023
f51cf25
allow negative indices when selecting a specific attention head for a…
lsickert Jan 4, 2023
c6a9e70
added missing negation to head checking
lsickert Jan 4, 2023
6c9cfae
fixed last_layer_attention attribution
lsickert Jan 4, 2023
b78bcc1
use custom format_attribute_args function for attention methods
lsickert Jan 9, 2023
d27f1c3
always use decoder_input_embeds in forward output
lsickert Jan 9, 2023
cacaa31
reworked LastLayerAttention to work with any single layer and allow a…
lsickert Jan 9, 2023
a8d5264
Minor bugfixes and version bumps
gsarti Jan 10, 2023
966f63c
Generalized attention attribution
gsarti Jan 10, 2023
1301a02
updated documentation and added 'min' aggregation function
lsickert Jan 13, 2023
914ee8f
Tests, typing fix, additional checks
gsarti Jan 14, 2023
7c825ad
Fix style
gsarti Jan 14, 2023
f6f0a64
added tests for attention utils
lsickert Jan 15, 2023
f40f63b
classmethod -> staticmethod where possible
gsarti Jan 16, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 161 additions & 14 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
# Created by https://www.toptal.com/developers/gitignore/api/osx,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks
# Edit at https://www.toptal.com/developers/gitignore?templates=osx,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks

# Created by https://www.gitignore.io/api/osx,python,pycharm,windows,visualstudio,visualstudiocode
# Edit at https://www.gitignore.io/?templates=osx,python,pycharm,windows,visualstudio,visualstudiocode
### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

### OSX ###
# General
Expand Down Expand Up @@ -31,7 +44,7 @@ Temporary Items
.apdisk

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
Expand All @@ -41,6 +54,9 @@ Temporary Items
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

Expand All @@ -61,6 +77,9 @@ Temporary Items
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
Expand Down Expand Up @@ -88,6 +107,9 @@ atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
Expand All @@ -109,15 +131,31 @@ fabric.properties
# *.ipr

# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/

# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml

# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/

# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$

# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml

# Azure Toolkit for IntelliJ plugin
# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
.idea/**/azureSettings.xml

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down Expand Up @@ -168,41 +206,85 @@ htmlcov/
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook

# IPython

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version

# poetry
.venv

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject
Expand All @@ -226,18 +308,41 @@ dmypy.json
# Pyre type checker
.pyre/

# Plugins
.secrets.baseline
# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### VisualStudioCode ###
.vscode/*
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# Support for Project snippet scope
.vscode/*.code-snippets

# Ignore code-workspaces
*.code-workspace

### Windows ###
# Windows thumbnail cache files
Expand Down Expand Up @@ -269,7 +374,7 @@ $RECYCLE.BIN/
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
##
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore

# User-specific files
*.rsuser
Expand All @@ -291,12 +396,14 @@ mono_crash.*
[Rr]eleases/
x64/
x86/
[Ww][Ii][Nn]32/
[Aa][Rr][Mm]/
[Aa][Rr][Mm]64/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
[Ll]ogs/

# Visual Studio 2015/2017 cache/options directory
.vs/
Expand Down Expand Up @@ -328,6 +435,9 @@ project.lock.json
project.fragment.lock.json
artifacts/

# ASP.NET Scaffolding
ScaffoldingReadMe.txt

# StyleCop
StyleCopReport.xml

Expand All @@ -336,6 +446,7 @@ StyleCopReport.xml
*_p.c
*_h.h
*.ilk
*.meta
*.obj
*.iobj
*.pch
Expand All @@ -351,7 +462,7 @@ StyleCopReport.xml
*.tmp
*.tmp_proj
*_wpftmp.csproj
*.log
*.tlog
*.vspscc
*.vssscc
.builds
Expand Down Expand Up @@ -406,6 +517,11 @@ _TeamCity*
.axoCover/*
!.axoCover/settings.json

# Coverlet is a free, cross platform Code Coverage Tool
coverage*.json
coverage*.xml
coverage*.info

# Visual Studio code coverage results
*.coverage
*.coveragexml
Expand Down Expand Up @@ -553,6 +669,15 @@ node_modules/
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
*.vbw

# Visual Studio 6 auto-generated project file (contains which files were open etc.)
*.vbp

# Visual Studio 6 workspace and project file (working project files containing files to include in project)
*.dsw
*.dsp

# Visual Studio 6 technical files

# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
Expand Down Expand Up @@ -608,12 +733,34 @@ ASALocalRun/
# Local History for Visual Studio
.localhistory/

# Visual Studio History (VSHistory) files
.vshistory/

# BeatPulse healthcheck temp database
healthchecksdb

# Backup folder for Package Reference Convert tool in Visual Studio 2017
MigrationBackup/

# End of https://www.gitignore.io/api/osx,python,pycharm,windows,visualstudio,visualstudiocode
# Ionide (cross platform F# VS Code tools) working folder
.ionide/

# Fody - auto-generated XML schema
FodyWeavers.xsd

# VS Code files for those working on multiple tools

# Local History for Visual Studio Code

# Windows Installer files from build outputs

# JetBrains Rider
*.sln.iml

### VisualStudio Patch ###
# Additional files built by Visual Studio

# End of https://www.toptal.com/developers/gitignore/api/osx,python,pycharm,windows,visualstudio,visualstudiocode,jupyternotebooks


/examples
/dev_examples
4 changes: 2 additions & 2 deletions CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ members of the project's leadership.
## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
available at <https://www.contributor-covenant.org/version/1/4/code-of-conduct.html>

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
<https://www.contributor-covenant.org/faq>
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ help:
@echo "check-safety : run safety checks on all tests."
@echo "lint : run linting on all files (check-style + check-safety)"
@echo "test : run all tests."
@echo "test-cpu : run all tests that do not depend on Torch GPU support."
@echo "fast-test : run all quick tests."
@echo "codecov : check coverage of all the code."
@echo "build-docs : build sphinx documentation."
Expand Down Expand Up @@ -103,6 +104,10 @@ lint: check-style check-safety
test:
poetry run pytest -c pyproject.toml -v

.PHONY: test-cpu
test-cpu:
poetry run pytest -c pyproject.toml -v -m "not require_cuda_gpu"

.PHONY: fast-test
fast-test:
poetry run pytest -c pyproject.toml -v -m "not slow"
Expand Down
Loading