
Commit: Add blob output
riccardobl committed Apr 19, 2024
1 parent d97b0fc commit fab05f2
Showing 1 changed file with 90 additions and 41 deletions.
131 changes: 90 additions & 41 deletions src/main.py
@@ -30,10 +30,9 @@ def __init__(self,cache_path, device=-1):
        os.makedirs(self.cache_path)

    def encode(self, sentences):
-
        to_encode = []
        to_encode_index=[]
-        out = []
+        out = []
        for s in sentences:
            hash = hashlib.sha256(s.encode()).hexdigest()
            cache_file = self.cache_path+"/"+hash+".dat"
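For context, encode() consults an on-disk cache keyed by the SHA-256 of each sentence before computing anything. The hunk is truncated before the read/write calls, so the following is a minimal sketch of that pattern, not the repository's exact code; the float32 dtype and the tofile/fromfile on-disk format are assumptions:

import hashlib
import os
import numpy as np

def cached_encode(cache_path, sentences, compute_fn):
    # Returns one vector per sentence, reusing cached .dat files when present.
    out = []
    for s in sentences:
        key = hashlib.sha256(s.encode()).hexdigest()               # cache key = hash of the text
        cache_file = os.path.join(cache_path, key + ".dat")
        if os.path.exists(cache_file):
            out.append(np.fromfile(cache_file, dtype=np.float32))  # assumed on-disk format
        else:
            vec = compute_fn(s)                                    # compute only on cache miss
            vec.tofile(cache_file)
            out.append(vec)
    return out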
@@ -66,7 +65,7 @@ def split(self, text, chunk_size, overlap , marker, out):
            out.append([chunk, marker])


-def completePendingJob(rpcClient , act):
+def completePendingJob(rpcClient , act, CACHE_PATH):
    jobs=[]
    jobs.extend(rpcClient.getPendingJobs(rpc_pb2.RpcGetPendingJobs(filterByRunOn="openagents\\/embeddings")).jobs)
    if len(jobs)>0 : log(rpcClient, str(len(jobs))+" pending jobs")
@@ -83,6 +82,10 @@ def getParamValue(key,default=None):
            max_tokens = int(getParamValue("max-tokens", "1024"))
            overlap = int(getParamValue("overlap", "128"))
            quantize = getParamValue("quantize", "true") == "true"

+            outputFormat = job.outputFormat
+
            sentences = []
            for jin in job.input:
@@ -100,43 +103,88 @@

                sentences.append([data,marker])

-            # Split long sentences
-            sentences_chunks=[]
-            for sentence in sentences:
-                act.split(sentence[0], max_tokens, overlap, sentence[1], sentences_chunks)
-            sentences = sentences_chunks
-            ##
-
-            log(rpcClient,"Create embeddings for "+str(len(sentences))+" excerpts. max_tokens="+str(max_tokens)+", overlap="+str(overlap), job.id)
-
-            t=time.time()
-            embeddings = act.encode([sentences[i][1]+": "+sentences[i][0] for i in range(len(sentences))])
-            if quantize:
-                embeddings = act.quantize(embeddings)
-
-            log(rpcClient,"Embeddings created in "+str(time.time()-t)+" seconds", job.id)
-
-            # TODO support multiple output type
-
-            output = []
-            for i in range(len(sentences)):
-
-                # encode
-                dtype = embeddings[i].dtype
-                shape = embeddings[i].shape
-                embeddings_bytes = embeddings[i].tobytes()
-                embeddings_b64 = base64.b64encode(embeddings_bytes).decode('utf-8')
-
-                output.append(
-                    [sentences[i][0], embeddings_b64, str(dtype), shape]
-                )
-
-                # decode
-                # embeddings_bytes = base64.b64decode(embeddings_b64)
-                # embeddings = np.frombuffer(embeddings_bytes, dtype=dtype).reshape(shape)
-
-            rpcClient.completeJob(rpc_pb2.RpcJobOutput(jobId=job.id, output=json.dumps(output)))
+            blobDiskUrl = None
+            blobRefId = None
+            blobRefFileName = None
+            blobDiskId = None
+            blobCached = False
+            if outputFormat=="application/hyperblob":
+                blobRefId = str(max_tokens) + str(overlap) + str(quantize) + "".join([sentences[i][0] + ":" + sentences[i][1] for i in range(len(sentences))])
+                blobRefId = hashlib.sha256(blobRefId.encode()).hexdigest()
+                blobRefFileName = CACHE_PATH+"/"+blobRefId+".blob"
+                if os.path.exists(blobRefFileName):
+                    with open(blobRefFileName, "r") as f:
+                        blobDiskUrl = f.read()
+                        blobCached = True
+                if not blobDiskUrl:
+                    blobDiskUrl=rpcClient.createDisk(rpc_pb2.RpcCreateDiskRequest()).url
+                    blobDiskId=rpcClient.openDisk(rpc_pb2.RpcOpenDiskRequest(url=blobDiskUrl)).diskId
+
+            if blobCached:
+                rpcClient.completeJob(rpc_pb2.RpcJobOutput(jobId=job.id, output=blobDiskUrl))
+            else:
+                # Split long sentences
+                sentences_chunks=[]
+                for sentence in sentences:
+                    act.split(sentence[0], max_tokens, overlap, sentence[1], sentences_chunks)
+                sentences = sentences_chunks
+                ##
+
+                log(rpcClient,"Create embeddings for "+str(len(sentences))+" excerpts. max_tokens="+str(max_tokens)+", overlap="+str(overlap), job.id)
+
+                t=time.time()
+                embeddings = act.encode([sentences[i][1]+": "+sentences[i][0] for i in range(len(sentences))])
+                if quantize:
+                    embeddings = act.quantize(embeddings)
+
+                log(rpcClient,"Embeddings created in "+str(time.time()-t)+" seconds", job.id)
+
+                if blobDiskId:
+                    # write on disk
+                    for i in range(len(sentences)):
+                        dtype = embeddings[i].dtype
+                        shape = embeddings[i].shape
+                        embeddings_bytes = embeddings[i].tobytes()
+                        rpcClient.diskWriteSmallFile(rpc_pb2.RpcDiskWriteFileRequest(diskId=blobDiskId, path=str(i)+".embeddings.dtype", data=str(dtype).encode("utf-8")))
+                        rpcClient.diskWriteSmallFile(rpc_pb2.RpcDiskWriteFileRequest(diskId=blobDiskId, path=str(i)+".embeddings.shape", data=json.dumps(shape).encode("utf-8")))
+
+                        CHUNK_SIZE = 1024
+                        def write_embeddings():
+                            for j in range(0, len(embeddings_bytes), CHUNK_SIZE):
+                                chunk = bytes(embeddings_bytes[j:j+CHUNK_SIZE])
+                                request = rpc_pb2.RpcDiskWriteFileRequest(diskId=str(blobDiskId), path=str(i)+".embeddings.vectors", data=chunk)
+                                yield request
+                        rpcClient.diskWriteFile(write_embeddings())
+
+                        sentences_bytes = sentences[i][0].encode("utf-8")
+                        def write_sentences():
+                            for j in range(0, len(sentences_bytes), CHUNK_SIZE):
+                                chunk = bytes(sentences_bytes[j:j+CHUNK_SIZE])
+                                request = rpc_pb2.RpcDiskWriteFileRequest(diskId=str(blobDiskId), path=str(i)+".embeddings", data=chunk)
+                                yield request
+                        rpcClient.diskWriteFile(write_sentences())
+
+                    with open(blobRefFileName, "w") as f:
+                        f.write(blobDiskUrl)
+
+                    rpcClient.completeJob(rpc_pb2.RpcJobOutput(jobId=job.id, output=blobDiskUrl))
+
+                else:
+                    output = []
+                    for i in range(len(sentences)):
+                        dtype = embeddings[i].dtype
+                        shape = embeddings[i].shape
+                        embeddings_bytes = embeddings[i].tobytes()
+                        embeddings_b64 = base64.b64encode(embeddings_bytes).decode('utf-8')
+                        output.append(
+                            [sentences[i][0], embeddings_b64, str(dtype), shape]
+                        )
+
+                    rpcClient.completeJob(rpc_pb2.RpcJobOutput(jobId=job.id, output=json.dumps(output)))

        except Exception as e:
            log(rpcClient, "Error processing job "+ str(e), job.id if job else None)
@@ -212,6 +260,7 @@ def getParamValue(key,default=None):
"kind": {{meta.kind}},
"created_at": {{sys.timestamp_seconds}},
"tags": [
["output", "application/hyperblob"]
["param","run-on", "openagents/embeddings" ],
["param", "max-tokens", "{{in.max_tokens}}"],
["param", "overlap", "{{in.overlap}}"],
@@ -274,7 +323,7 @@ def main():
log(stub, "Error announcing node "+ str(e), None)

try:
completePendingJob(stub, t)
completePendingJob(stub, t, CACHE_PATH)
except Exception as e:
log(stub, "Error processing pending jobs "+ str(e), None)
time.sleep(100.0/1000.0)
