Find out if the cast to long for batch components in forward functions is necessary #6

Open
nfelnlp opened this issue May 3, 2021 · 1 comment
Assignees
Labels
question Further information is requested

Comments

@nfelnlp
Member

nfelnlp commented May 3, 2021

https://github.com/nfelnlp/thermostat/blob/a8180a2d83e1c3ec5f873dbf0ce0ab14026cf6bf/src/thermostat/explain.py#L54

        def bert_forward(input_ids, attention_mask, token_type_ids):
            input_model = {
                'input_ids': input_ids.long(),
                'attention_mask': attention_mask.long(),
                'token_type_ids': token_type_ids.long(),
            }
            output_model = model(**input_model)[0]
            return output_model

        def roberta_forward(input_ids, attention_mask):
            input_model = {
                'input_ids': input_ids.long(),
                'attention_mask': attention_mask.long(),
            }
            output_model = model(**input_model)[0]
            return output_model
@nfelnlp nfelnlp added the question label May 3, 2021
@nfelnlp nfelnlp self-assigned this May 3, 2021
@rbtsbg
Collaborator

rbtsbg commented May 31, 2021

We could wrap the model call in a try/except block so that the cast is skipped automatically once a later version of Captum fixes the issue.
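
A minimal sketch of that idea, not the code currently in `explain.py`. It assumes that a model given un-cast inputs fails with a `RuntimeError` or `TypeError` (an assumption about how the dtype mismatch surfaces), and `make_forward` is a hypothetical helper name. The forward function first passes the batch components through unchanged and only falls back to `.long()` if the model rejects them.

```python
def make_forward(model):
    """Build a forward function that casts inputs to long only when needed.

    Hypothetical helper: `model` is any callable that accepts keyword
    tensors and returns a tuple whose first element is the output.
    """
    def forward(input_ids, attention_mask, token_type_ids=None):
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }
        # BERT-style models also take token_type_ids; RoBERTa-style ones do not.
        if token_type_ids is not None:
            inputs['token_type_ids'] = token_type_ids
        try:
            # First attempt: no cast. This succeeds if the inputs already
            # arrive with an accepted dtype.
            return model(**inputs)[0]
        except (RuntimeError, TypeError):
            # Assumed failure mode for a dtype mismatch: fall back to the
            # existing workaround and cast every component to long.
            casted = {name: tensor.long() for name, tensor in inputs.items()}
            return model(**casted)[0]
    return forward
```

One trade-off: while the cast is still required, every call pays for a failed first attempt, and the broad `except` can also mask unrelated runtime errors on that first attempt (they would resurface on the retry). Checking the input dtype explicitly may be the cheaper option if the expected dtype is known.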

2 participants