Skip to content

Commit

Permalink
Merge branch 'dave90-issue_897_new' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
pieroit committed Nov 13, 2024
2 parents c49a371 + f427795 commit 75bf942
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 0 deletions.
94 changes: 94 additions & 0 deletions core/cat/routes/memory/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,97 @@ async def get_points_in_collection(
"points": points,
"next_offset": next_offset
}



# EDIT a point in memory
@router.put("/collections/{collection_id}/points/{point_id}", response_model=MemoryPoint)
async def edit_memory_point(
request: Request,
collection_id: str,
point_id: str,
point: MemoryPointBase,
stray: StrayCat = Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.EDIT)),
) -> MemoryPoint:
"""Edit a point in memory
Example
----------
```
collection = "declarative"
content = "MIAO!"
metadata = {"custom_key": "custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}
# create a point
res = requests.post(
f"http://localhost:1865/memory/collections/{collection}/points", json=req_json
)
json = res.json()
#get the id
point_id = json["id"]
# new point values
content = "NEW MIAO!"
metadata = {"custom_key": "new_custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}
# edit the point
res = requests.put(
f"http://localhost:1865/memory/collections/{collection}/points/{point_id}", json=req_json
)
json = res.json()
print(json)
```
"""

# do not touch procedural memory
if collection_id == "procedural":
raise HTTPException(
status_code=400, detail={"error": "Procedural memory is read-only."}
)

vector_memory: VectorMemory = stray.memory.vectors
collections = list(vector_memory.collections.keys())
if collection_id not in collections:
raise HTTPException(
status_code=400, detail={"error": "Collection does not exist."}
)

#ensure point exist
points = vector_memory.collections[collection_id].get_points([point_id])
if points is None or len(points) == 0:
raise HTTPException(
status_code=400, detail={"error": "Point does not exist."}
)

# embed content
embedding = stray.embedder.embed_query(point.content)

# ensure source is set
if not point.metadata.get("source"):
point.metadata["source"] = (
stray.user_id
) # this will do also for declarative memory

# ensure when is set
if not point.metadata.get("when"):
point.metadata["when"] = time.time() #if when is not in the metadata set the current time

# edit point
qdrant_point = vector_memory.collections[collection_id].add_point(
content=point.content, vector=embedding, metadata=point.metadata, id=point_id
)

return MemoryPoint(
metadata=qdrant_point.payload["metadata"],
content=qdrant_point.payload["page_content"],
vector=qdrant_point.vector,
id=qdrant_point.id,
)
82 changes: 82 additions & 0 deletions core/tests/routes/memory/test_memory_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,85 @@ def test_get_collection_points_offset(client, patch_time_now, collection):
assert points_payloads == expected_payloads



def test_edit_point_wrong_collection_and_not_exist(client):

req_json = {
"content": "MIAO!"
}

point_id = 100

# wrong collection
res = client.put(
f"/memory/collections/wrongcollection/points/{point_id}", json=req_json
)
assert res.status_code == 400
assert "Collection does not exist" in res.json()["detail"]["error"]

# cannot write procedural point
res = client.put(
"/memory/collections/procedural/points/{point_id}", json=req_json
)
assert res.status_code == 400
assert "Procedural memory is read-only" in res.json()["detail"]["error"]

# point do not exist
res = client.put(
"/memory/collections/declarative/points/{point_id}", json=req_json
)
assert res.status_code == 400
assert "Point does not exist." in res.json()["detail"]["error"]



@pytest.mark.parametrize("collection", ["episodic", "declarative"])
def test_edit_memory_point(client, patch_time_now, collection):

# create a point
content = "MIAO!"
metadata = {"custom_key": "custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}
# create a point
res = client.post(
f"/memory/collections/{collection}/points", json=req_json
)
#get the id
assert res.status_code == 200
json = res.json()
assert json["id"]
point_id = json["id"]
# new point values
content = "NEW MIAO!"
metadata = {"custom_key": "new_custom_value"}
req_json = {
"content": content,
"metadata": metadata,
}

res = client.put(
f"/memory/collections/{collection}/points/{point_id}", json=req_json
)
# check response
assert res.status_code == 200
json = res.json()
assert json["content"] == content
expected_metadata = {"when":FAKE_TIMESTAMP,"source": "user", **metadata}
assert json["metadata"] == expected_metadata
assert "id" in json
assert "vector" in json
assert isinstance(json["vector"], list)
assert isinstance(json["vector"][0], float)

# check memory contents
params = {"text": "miao"}
response = client.get("/memory/recall/", params=params)
json = response.json()
assert response.status_code == 200
assert len(json["vectors"]["collections"][collection]) == 1
memory = json["vectors"]["collections"][collection][0]
assert memory["page_content"] == content
assert memory["metadata"] == expected_metadata

0 comments on commit 75bf942

Please sign in to comment.