Fix for empty pd.DataFrame in WebQSPDataset (#9665)
This last-minute commit broke things:
e331dcd

The `columns=...` change fixes:
```
Traceback (most recent call last):
  File "/opt/pyg/pytorch_geometric/examples/llm/g_retriever.py", line 262, in <module>
    train(
  File "/opt/pyg/pytorch_geometric/examples/llm/g_retriever.py", line 132, in train
    train_dataset = WebQSPDataset(path, split='train')
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/datasets/web_qsp_dataset.py", line 145, in __init__
    super().__init__(root, force_reload=force_reload)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/data/in_memory_dataset.py", line 81, in __init__
    super().__init__(root, transform, pre_transform, pre_filter, log,
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/data/dataset.py", line 115, in __init__
    self._process()
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/data/dataset.py", line 262, in _process
    self.process()
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/datasets/web_qsp_dataset.py", line 202, in process
    nodes.node_attr = nodes.node_attr.fillna("")
  File "/usr/local/lib/python3.10/dist-packages/pandas/core/generic.py", line 6299, in __getattr__
    return object.__getattribute__(self, name)
AttributeError: 'DataFrame' object has no attribute 'node_attr'

```
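A minimal sketch of why the traceback above occurs, assuming an empty `raw_nodes` dict stands in for an example without a knowledge graph (the variable names here are illustrative, not the dataset's exact state). An empty list of records gives pandas nothing to infer columns from, so attribute access fails unless `columns=` is passed explicitly:

```python
import pandas as pd

# Hypothetical stand-in: a question-answer pair with no knowledge graph.
raw_nodes = {}

rows = [{"node_id": v, "node_attr": k} for k, v in raw_nodes.items()]

# With no rows, pandas cannot infer any column names.
without_columns = pd.DataFrame(rows)
# Passing columns= guarantees the schema even when there are no rows.
with_columns = pd.DataFrame(rows, columns=["node_id", "node_attr"])

has_attr_without = hasattr(without_columns, "node_attr")
has_attr_with = hasattr(with_columns, "node_attr")

# This is the line that failed in the traceback; it now works on the empty frame.
with_columns.node_attr = with_columns.node_attr.fillna("")
```

Here `has_attr_without` is `False` and `has_attr_with` is `True`, which is the difference between the failing and the fixed constructor call.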
The LongTensor change (`dtype=torch.long`) fixes:

```
    return scatter(edge_attr, col, 0, num_nodes, fill_value)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/utils/_scatter.py", line 79, in scatter
    count.scatter_add_(0, index, src.new_ones(src.size(dim)))
RuntimeError: scatter(): Expected dtype int64 for index
```
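A small sketch of the dtype issue behind the error above, assuming empty `src` and `dst` lists (hypothetical names) and PyTorch's default dtype left unchanged. With no integer values to infer from, `torch.tensor` falls back to the default floating-point dtype, which index-based ops such as `scatter_add_` reject:

```python
import torch

# Hypothetical stand-in for edges.src.tolist() / edges.dst.tolist()
# when the example has no edges.
src, dst = [], []

# No elements to infer from, so the default (floating-point) dtype is used.
default_index = torch.tensor([src, dst])
# An explicit dtype keeps edge_index usable as an index tensor.
long_index = torch.tensor([src, dst], dtype=torch.long)

print(default_index.dtype, long_index.dtype)
```

Both tensors have shape `(2, 0)`; only the second has the int64 dtype that scatter operations expect for an index.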

Both changes handle the edge case where a question-answer pair has no
associated knowledge graph.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
4 people committed Sep 17, 2024
1 parent 7e3a24d commit 642f831
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed `WebQSDataset.process` raising exceptions ([#9665](https://github.com/pyg-team/pytorch_geometric/pull/9665))
+
 ### Removed

 ## \[2.6.0\] - 2024-09-13
8 changes: 5 additions & 3 deletions torch_geometric/datasets/web_qsp_dataset.py
@@ -196,8 +196,10 @@ def process(self) -> None:
 nodes = pd.DataFrame([{
     "node_id": v,
     "node_attr": k,
-} for k, v in raw_nodes.items()])
-edges = pd.DataFrame(raw_edges)
+} for k, v in raw_nodes.items()],
+                     columns=["node_id", "node_attr"])
+edges = pd.DataFrame(raw_edges,
+                     columns=["src", "edge_attr", "dst"])

 nodes.node_attr = nodes.node_attr.fillna("")
 x = model.encode(
@@ -213,7 +215,7 @@ def process(self) -> None:
 edge_index = torch.tensor([
     edges.src.tolist(),
     edges.dst.tolist(),
-])
+], dtype=torch.long)

 question = f"Question: {example['question']}\nAnswer: "
 label = ('|').join(example['answer']).lower()