-
Notifications
You must be signed in to change notification settings - Fork 0
/
lance_test.py
85 lines (52 loc) · 1.62 KB
/
lance_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import lancedb
import numpy as np
# Open the local LanceDB database stored in the "lancedb" directory (a path
# relative to the current working directory). Everything below uses `db`.
db = lancedb.connect("lancedb")
# --------------------------
import pandas as pd
import pyarrow as pa
# Schema of the demo table "table5": a fixed-size list of 128 float32 values
# (the embedding searched below) plus two scalar columns used for filtering
# and projection.
schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), 128)),
        pa.field("item", pa.string()),
        pa.field("price", pa.float32()),
    ]
)

# Create the table on first run so the script also works against an empty
# "lancedb" directory; with exist_ok=True an already-existing table is
# returned instead of raising.
tbl = db.create_table("table5", schema=schema, exist_ok=True)

# Seed two demo rows only when the table is empty, so repeated runs do not
# keep appending duplicates.
if tbl.count_rows() == 0:
    tbl.add(
        data=[
            {"vector": np.random.rand(128), "item": "foo", "price": 10.0},
            {"vector": np.random.rand(128), "item": "bar", "price": 20.0},
        ]
    )

# Get the whole table as a pandas DataFrame and print it.
df = tbl.to_pandas()
print(df)

# Nearest-neighbour search from the all-zeros query vector on the "vector"
# column. prefilter=True restricts candidates to item = 'foo' before the
# search; only "item" and "price" are selected, at most 2 rows are returned.
df = (
    tbl.search(np.zeros(128), vector_column_name="vector")
    .where("item = 'foo'", prefilter=True)
    .select(["item", "price"])
    .limit(2)
    .to_pandas()
)
print(df)