Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PUT Endpoint for Updating Single Point by ID #971

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions core/cat/routes/memory/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,97 @@ async def get_points_in_collection(
"points": points,
"next_offset": next_offset
}



# EDIT a point in memory
@router.put("/collections/{collection_id}/points/{point_id}", response_model=MemoryPoint)
async def edit_memory_point(
request: Request,
collection_id: str,
point_id: str,
point: MemoryPointBase,
stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.EDIT)),
) -> MemoryPoint:
"""Edit a point in memory


Example
----------
```

collection = "declarative"
content = "MIAO!"
metadata = {"custom_key": "custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}
# create a point
res = requests.post(
f"http://localhost:1865/memory/collections/{collection}/points", json=req_json
)
json = res.json()
#get the id
point_id = json["id"]
# new point values
content = "NEW MIAO!"
metadata = {"custom_key": "new_custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}

# edit the point
res = requests.put(
f"http://localhost:1865/memory/collections/{collection}/points/{point_id}", json=req_json
)
json = res.json()
print(json)
```
"""

# do not touch procedural memory
if collection_id == "procedural":
raise HTTPException(
status_code=400, detail={"error": "Procedural memory is read-only."}
)

vector_memory: VectorMemory = stray.memory.vectors
collections = list(vector_memory.collections.keys())
if collection_id not in collections:
raise HTTPException(
status_code=400, detail={"error": "Collection does not exist."}
)

#ensure point exist
points = vector_memory.collections[collection_id].get_points([point_id])
if points is None or len(points) == 0:
raise HTTPException(
status_code=400, detail={"error": "Point does not exist."}
)

# embed content
embedding = stray.embedder.embed_query(point.content)

# ensure source is set
if not point.metadata.get("source"):
point.metadata["source"] = (
stray.user_id
) # this will do also for declarative memory

# ensure when is set
if not point.metadata.get("when"):
point.metadata["when"] = time.time() #if when is not in the metadata set the current time

# edit point
qdrant_point = vector_memory.collections[collection_id].add_point(
content=point.content, vector=embedding, metadata=point.metadata, id=point_id
)

return MemoryPoint(
metadata=qdrant_point.payload["metadata"],
content=qdrant_point.payload["page_content"],
vector=qdrant_point.vector,
id=qdrant_point.id,
)
82 changes: 82 additions & 0 deletions core/tests/routes/memory/test_memory_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,85 @@ def test_get_collection_points_offset(client, patch_time_now, collection):
assert points_payloads == expected_payloads



def test_edit_point_wrong_collection_and_not_exist(client):

req_json = {
"content": "MIAO!"
}

point_id = 100

# wrong collection
res = client.put(
f"/memory/collections/wrongcollection/points/{point_id}", json=req_json
)
assert res.status_code == 400
assert "Collection does not exist" in res.json()["detail"]["error"]

# cannot write procedural point
res = client.put(
"/memory/collections/procedural/points/{point_id}", json=req_json
)
assert res.status_code == 400
assert "Procedural memory is read-only" in res.json()["detail"]["error"]

# point do not exist
res = client.put(
"/memory/collections/declarative/points/{point_id}", json=req_json
)
assert res.status_code == 400
assert "Point does not exist." in res.json()["detail"]["error"]



@pytest.mark.parametrize("collection", ["episodic", "declarative"])
def test_edit_memory_point(client, patch_time_now, collection):

# create a point
content = "MIAO!"
metadata = {"custom_key": "custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}
# create a point
res = client.post(
f"/memory/collections/{collection}/points", json=req_json
)
#get the id
assert res.status_code == 200
json = res.json()
assert json["id"]
point_id = json["id"]
# new point values
content = "NEW MIAO!"
metadata = {"custom_key": "new_custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}

res = client.put(
f"/memory/collections/{collection}/points/{point_id}", json=req_json
)
# check response
assert res.status_code == 200
json = res.json()
assert json["content"] == content
expected_metadata = {"when":FAKE_TIMESTAMP,"source": "user", **metadata}
assert json["metadata"] == expected_metadata
assert "id" in json
assert "vector" in json
assert isinstance(json["vector"], list)
assert isinstance(json["vector"][0], float)

# check memory contents
params = {"text": "miao"}
response = client.get("/memory/recall/", params=params)
json = response.json()
assert response.status_code == 200
assert len(json["vectors"]["collections"][collection]) == 1
memory = json["vectors"]["collections"][collection][0]
assert memory["page_content"] == content
assert memory["metadata"] == expected_metadata
Loading