Any ways to support torch functions #546

Open
ioangatop opened this issue Jul 8, 2024 · 2 comments
Labels
enhancement New feature or request

Comments

@ioangatop
Contributor

ioangatop commented Jul 8, 2024

🚀 Feature request

Hi! I would like to use torch functions directly from the YAML file, for example:

class_path: SomeClass
init_args:
  process:
    class_path: torch.argmax
    init_args:
      dim: 1

However, I have had a hard time getting this to work, because the type check fails for a class like the following:

class SomeClass:
  def __init__(self, process: Callable[..., torch.Tensor]) -> None:
    ...

The workaround is to wrap them in callable classes, but it would be great to support them directly, like the dot imports for the torch optimizers, so that I don't have to duplicate them or create a wrapper class.

@ioangatop ioangatop added the enhancement New feature or request label Jul 8, 2024
@mauvilsa
Member

mauvilsa commented Jul 9, 2024

I don't think this is possible. Running inspect.signature(torch.argmax) simply fails, and this could not even be fixed on the PyTorch side. The problem is that torch.argmax has multiple signatures (see help(torch.argmax)), which is something that native Python functions don't support.
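As a rough illustration of the check described above, the sketch below wraps inspect.signature in a small helper that reports whether a callable can be introspected. The names has_introspectable_signature and plain_function are hypothetical and not part of jsonargparse or torch; the example sticks to the standard library so it runs without torch installed.

```python
import inspect
from typing import Any, Callable


def has_introspectable_signature(fn: Callable[..., Any]) -> bool:
    """Return True if inspect.signature can describe fn's parameters."""
    try:
        inspect.signature(fn)
    except (ValueError, TypeError):
        # inspect.signature raises ValueError when no signature can be
        # found, and TypeError when the object type is not supported.
        return False
    return True


def plain_function(x, dim=0):
    """Hypothetical stand-in: an ordinary function with a single signature."""
    return x


print(has_introspectable_signature(plain_function))
# With torch installed, the same helper could be applied to torch.argmax.
```

A parser that derives its arguments from signatures needs this kind of check to succeed before it can validate anything like dim: 1.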

I will keep this in mind in case a better idea comes up. But for now I think a wrapper class is the best option. Possibly a single class which gets the torch function name so that there is no need for one class for each function.
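A minimal sketch of such a single wrapper class could look like the following. FunctionWrapper, import_callable, and the parameter names path and kwargs are all hypothetical, not something jsonargparse provides. Because the wrapper is an ordinary class with a typed __init__, its signature can be inspected even when the wrapped function's cannot. A standard-library function is used as a stand-in so the sketch runs without torch.

```python
import functools
import importlib
from typing import Any, Callable, Dict, Optional


def import_callable(path: str) -> Callable[..., Any]:
    """Resolve a dotted path such as "torch.argmax" to the callable it names."""
    module_name, _, attr = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    obj = getattr(importlib.import_module(module_name), attr)
    if not callable(obj):
        raise TypeError(f"{path!r} is not callable")
    return obj


class FunctionWrapper:
    """Hypothetical wrapper: binds keyword arguments to a function given by path."""

    def __init__(self, path: str, kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.fn = functools.partial(import_callable(path), **(kwargs or {}))

    def __call__(self, *args: Any, **extra: Any) -> Any:
        return self.fn(*args, **extra)


# Stand-in so the sketch runs without torch; with torch installed,
# path="torch.argmax" and kwargs={"dim": 1} would be the analogous call.
to_binary = FunctionWrapper("builtins.int", {"base": 2})
print(to_binary("101"))
```

In the YAML file this would presumably be configured with class_path pointing at FunctionWrapper and init_args holding path and kwargs, assuming the class is importable from the user's own module.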

@ioangatop
Contributor Author

ioangatop commented Jul 9, 2024

Thanks for the fast response!

Possibly a single class which gets the torch function name so that there is no need for one class for each function.

This is also what I did, but I also came up with a different, somewhat hacky idea: pass it as a dict and parse it later into a partial function. For example:

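import functools
from typing import Any, Callable, Dict

import jsonargparse
import torch

class SomeClass:
  def __init__(self, process: Callable[..., torch.Tensor] | Dict[str, Any]) -> None:
    # A dict from the config is turned into a partial; a callable is kept as-is.
    self.process = self.parse(process) if isinstance(process, dict) else process

  def parse(self, item):
    # Note: jsonargparse._util is a private module and may change between releases.
    return functools.partial(
      jsonargparse._util.import_object(item["class_path"]), **item.get("init_args", {})
    )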

If you have any other ideas on how to improve it, please do let me know 🙏
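One possible variation on the dict-based parsing above, sketched under the assumption that Python 3.9 or newer is available: resolve the dotted path with the standard library's pkgutil.resolve_name instead of the private jsonargparse._util module. The function name parse_callable is hypothetical, and a standard-library function stands in for torch.argmax so the example runs on its own.

```python
import functools
import pkgutil
from typing import Any, Callable, Dict, Union


def parse_callable(item: Union[Callable[..., Any], Dict[str, Any]]) -> Callable[..., Any]:
    """Turn a {"class_path": ..., "init_args": ...} dict into a partial function.

    Anything that is already callable is returned unchanged.
    """
    if callable(item):
        return item
    if not isinstance(item, dict) or "class_path" not in item:
        raise TypeError("expected a callable or a dict with a 'class_path' key")
    # pkgutil.resolve_name (Python 3.9+) imports the module part of the
    # dotted path and then looks up the remaining attribute(s).
    target = pkgutil.resolve_name(item["class_path"])
    if not callable(target):
        raise TypeError(f"{item['class_path']!r} is not callable")
    return functools.partial(target, **item.get("init_args", {}))


# Stand-in config so the sketch runs without torch; the analogous torch entry
# would be {"class_path": "torch.argmax", "init_args": {"dim": 1}}.
config = {"class_path": "builtins.int", "init_args": {"base": 2}}
process = parse_callable(config)
print(process("101"))
```

This keeps the same dict layout as the YAML in the original request, so the config file would not need to change if the import helper is swapped.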
