From ff1c6e5af7b9abc4bc704bd5588e214ca41026ce Mon Sep 17 00:00:00 2001 From: jfdreis Date: Wed, 26 Mar 2025 17:25:47 +0000 Subject: [PATCH 1/4] Measuring time and communication --- nilai-api/src/nilai_api/handlers/nilrag.py | 53 +++++++++++++++++++++- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index 83ae7a8b..d7054e03 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -1,5 +1,7 @@ import logging import numpy as np +import time +import sys import nilql import nilrag @@ -20,6 +22,8 @@ "sentence-transformers/all-MiniLM-L6-v2", device="cpu" ) # FIXME: Use a GPU model and move to a separate container +def get_size_in_MB(obj): + return sys.getsizeof(obj) / (1024 * 1024) def generate_embeddings_huggingface( chunks_or_query: Union[str, list], @@ -73,14 +77,18 @@ def handle_nilrag(req: ChatRequest): nilDB = nilrag.NilDB(nodes) # Initialize secret keys + start_time = time.time() num_parties = len(nilDB.nodes) additive_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"sum": True}) xor_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"store": True}) + end_time = time.time() + print(f"Initialization of secret key took {end_time - start_time:.2f} seconds") # Step 2: Secret share query logger.debug("Secret sharing query and sending to NilDB...") # 2.1 Extract the user query query = None + start_time = time.time() for message in req.messages: if message.role == "user": query = message.content @@ -88,28 +96,49 @@ def handle_nilrag(req: ChatRequest): if query is None: raise HTTPException(status_code=400, detail="No user query found") + end_time = time.time() + print(f"Time to extract user query {end_time - start_time:.2f} seconds") # 2.2 Generate query embeddings: one string query is assumed. 
+ start_time = time.time() query_embedding = generate_embeddings_huggingface([query])[0] nilql_query_embedding = encrypt_float_list(additive_key, query_embedding) + end_time = time.time() + print(f"Time to generate query embedding {end_time - start_time:.2f} seconds") + query_size = get_size_in_MB(nilql_query_embedding) + print(f"Size of secret-shared query sent to NilDB: {query_size:.3f} MB") + # Step 3: Ask NilDB to compute the differences logger.debug("Requesting computation from NilDB...") + start_time = time.time() difference_shares = nilDB.diff_query_execute(nilql_query_embedding) + end_time = time.time() + print(f"Time to ask nilDB to compute the differences {end_time - start_time:.2f} seconds") + diff_shares_size = get_size_in_MB(difference_shares) + print(f"Size of difference shares received: {diff_shares_size:.3f} MB") + # Step 4: Compute distances and sort logger.debug("Compute distances and sort...") # 4.1 Group difference shares by ID + start_time = time.time() difference_shares_by_id = group_shares_by_id( difference_shares, # type: ignore lambda share: share["difference"], ) + end_time = time.time() + print(f"Time to Group difference shares by ID {end_time - start_time:.2f} seconds") # 4.2 Transpose the lists for each _id + start_time = time.time() difference_shares_by_id = { id: np.array(differences).T.tolist() for id, differences in difference_shares_by_id.items() } + end_time = time.time() + print(f"Time to Transpose the lists for each _id {end_time - start_time:.2f} seconds") # 4.3 Decrypt and compute distances + start_time = time.time() reconstructed = [ { "_id": id, @@ -119,36 +148,55 @@ def handle_nilrag(req: ChatRequest): } for id, difference_shares in difference_shares_by_id.items() ] + end_time = time.time() + print(f"Time to Decrypt and compute distances {end_time - start_time:.2f} seconds") # 4.4 Sort id list based on the corresponding distances + start_time = time.time() sorted_ids = sorted(reconstructed, key=lambda x: x["distances"]) - + end_time = time.time() + print(f"Time to Sort id list based on the corresponding distances {end_time - start_time:.2f} seconds") # Step 5: Query the top k logger.debug("Query top k chunks...") top_k = 2 top_k_ids = [item["_id"] for item in sorted_ids[:top_k]] # 5.1 Query top k + start_time = time.time() chunk_shares = nilDB.chunk_query_execute(top_k_ids) + end_time = time.time() + print(f"Time to Query top k {end_time - start_time:.2f} seconds") + chunk_shares_size = get_size_in_MB(chunk_shares) + print(f"Size of chunk shares received: {chunk_shares_size:.3f} MB") # 5.2 Group chunk shares by ID + start_time = time.time() chunk_shares_by_id = group_shares_by_id( chunk_shares, # type: ignore lambda share: share["chunk"], ) + end_time = time.time() + print(f"Time to Group chunk shares by ID {end_time - start_time:.2f} seconds") # 5.3 Decrypt chunks + start_time = time.time() top_results = [ {"_id": id, "distances": nilql.decrypt(xor_key, chunk_shares)} for id, chunk_shares in chunk_shares_by_id.items() ] + end_time = time.time() + print(f"Time to decrypt chunk {end_time - start_time:.2f} seconds") # Step 6: Format top results + start_time = time.time() formatted_results = "\n".join( f"- {str(result['distances'])}" for result in top_results ) relevant_context = f"\n\nRelevant Context:\n{formatted_results}" + end_time = time.time() + print(f"Time to format top resuls {end_time - start_time:.2f} seconds") # Step 7: Update system message + start_time = time.time() for message in req.messages: if message.role == "system": if 
message.content is None: @@ -163,8 +211,9 @@ def handle_nilrag(req: ChatRequest): else: # If no system message exists, add one req.messages.insert(0, Message(role="system", content=relevant_context)) - + end_time = time.time() logger.debug(f"System message updated with relevant context:\n {req.messages}") + print(f"Time to update system message {end_time - start_time:.2f} seconds") except Exception as e: logger.error("An error occurred within nilrag: %s", str(e)) From 1bd9a20435fc7c9bda5eadf39f8d2fcfe71363cc Mon Sep 17 00:00:00 2001 From: jfdreis Date: Thu, 27 Mar 2025 13:10:17 +0000 Subject: [PATCH 2/4] feat: add flag for communication measurements --- nilai-api/src/nilai_api/handlers/nilrag.py | 81 +++++++++++++--------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index d7054e03..fabededf 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -2,6 +2,7 @@ import numpy as np import time import sys +import os import nilql import nilrag @@ -22,6 +23,9 @@ "sentence-transformers/all-MiniLM-L6-v2", device="cpu" ) # FIXME: Use a GPU model and move to a separate container +#Retrieve the ENABLE_MEASUREMENTS flag from environment variable +ENABLE_MEASUREMENTS = os.getenv("ENABLE_MEASUREMENTS", "0") in ["1", "True"] + def get_size_in_MB(obj): return sys.getsizeof(obj) / (1024 * 1024) @@ -81,8 +85,9 @@ def handle_nilrag(req: ChatRequest): num_parties = len(nilDB.nodes) additive_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"sum": True}) xor_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"store": True}) - end_time = time.time() - print(f"Initialization of secret key took {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Initialization of secret keys took {end_time - start_time:.2f} seconds") # Step 2: Secret share query logger.debug("Secret sharing query and sending to NilDB...") @@ -96,27 +101,30 @@ def handle_nilrag(req: ChatRequest): if query is None: raise HTTPException(status_code=400, detail="No user query found") - end_time = time.time() - print(f"Time to extract user query {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to extract user query {end_time - start_time:.2f} seconds") # 2.2 Generate query embeddings: one string query is assumed. 
start_time = time.time() query_embedding = generate_embeddings_huggingface([query])[0] nilql_query_embedding = encrypt_float_list(additive_key, query_embedding) - end_time = time.time() - print(f"Time to generate query embedding {end_time - start_time:.2f} seconds") - query_size = get_size_in_MB(nilql_query_embedding) - print(f"Size of secret-shared query sent to NilDB: {query_size:.3f} MB") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to generate query embedding {end_time - start_time:.2f} seconds") + query_size = get_size_in_MB(nilql_query_embedding) + print(f"Size of secret-shared query sent to NilDB: {query_size:.3f} MB") # Step 3: Ask NilDB to compute the differences logger.debug("Requesting computation from NilDB...") start_time = time.time() difference_shares = nilDB.diff_query_execute(nilql_query_embedding) - end_time = time.time() - print(f"Time to ask nilDB to compute the differences {end_time - start_time:.2f} seconds") - diff_shares_size = get_size_in_MB(difference_shares) - print(f"Size of difference shares received: {diff_shares_size:.3f} MB") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to ask nilDB to compute the differences {end_time - start_time:.2f} seconds") + diff_shares_size = get_size_in_MB(difference_shares) + print(f"Size of difference shares received: {diff_shares_size:.3f} MB") # Step 4: Compute distances and sort @@ -127,16 +135,18 @@ def handle_nilrag(req: ChatRequest): difference_shares, # type: ignore lambda share: share["difference"], ) - end_time = time.time() - print(f"Time to Group difference shares by ID {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to Group difference shares by ID {end_time - start_time:.2f} seconds") # 4.2 Transpose the lists for each _id start_time = time.time() difference_shares_by_id = { id: np.array(differences).T.tolist() for id, differences in difference_shares_by_id.items() } - end_time = time.time() - print(f"Time to Transpose the lists for each _id {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to Transpose the lists for each _id {end_time - start_time:.2f} seconds") # 4.3 Decrypt and compute distances start_time = time.time() reconstructed = [ @@ -148,13 +158,15 @@ def handle_nilrag(req: ChatRequest): } for id, difference_shares in difference_shares_by_id.items() ] - end_time = time.time() - print(f"Time to Decrypt and compute distances {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to Decrypt and compute distances {end_time - start_time:.2f} seconds") # 4.4 Sort id list based on the corresponding distances start_time = time.time() sorted_ids = sorted(reconstructed, key=lambda x: x["distances"]) - end_time = time.time() - print(f"Time to Sort id list based on the corresponding distances {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to Sort id list based on the corresponding distances {end_time - start_time:.2f} seconds") # Step 5: Query the top k logger.debug("Query top k chunks...") top_k = 2 @@ -163,10 +175,11 @@ def handle_nilrag(req: ChatRequest): # 5.1 Query top k start_time = time.time() chunk_shares = nilDB.chunk_query_execute(top_k_ids) - end_time = time.time() - print(f"Time to Query top k {end_time - start_time:.2f} seconds") - chunk_shares_size = get_size_in_MB(chunk_shares) - print(f"Size of chunk shares received: {chunk_shares_size:.3f} MB") + if 
ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to Query {top_k} chunks {end_time - start_time:.2f} seconds") + chunk_shares_size = get_size_in_MB(chunk_shares) + print(f"Size of chunk shares received: {chunk_shares_size:.3f} MB") # 5.2 Group chunk shares by ID start_time = time.time() @@ -174,8 +187,9 @@ def handle_nilrag(req: ChatRequest): chunk_shares, # type: ignore lambda share: share["chunk"], ) - end_time = time.time() - print(f"Time to Group chunk shares by ID {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to Group chunk shares by ID {end_time - start_time:.2f} seconds") # 5.3 Decrypt chunks start_time = time.time() @@ -183,8 +197,9 @@ def handle_nilrag(req: ChatRequest): {"_id": id, "distances": nilql.decrypt(xor_key, chunk_shares)} for id, chunk_shares in chunk_shares_by_id.items() ] - end_time = time.time() - print(f"Time to decrypt chunk {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to decrypt chunk {end_time - start_time:.2f} seconds") # Step 6: Format top results start_time = time.time() @@ -192,8 +207,9 @@ def handle_nilrag(req: ChatRequest): f"- {str(result['distances'])}" for result in top_results ) relevant_context = f"\n\nRelevant Context:\n{formatted_results}" - end_time = time.time() - print(f"Time to format top resuls {end_time - start_time:.2f} seconds") + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to format top resuls {end_time - start_time:.2f} seconds") # Step 7: Update system message start_time = time.time() @@ -211,9 +227,10 @@ def handle_nilrag(req: ChatRequest): else: # If no system message exists, add one req.messages.insert(0, Message(role="system", content=relevant_context)) - end_time = time.time() + if ENABLE_MEASUREMENTS: + end_time = time.time() + print(f"Time to update system message {end_time - start_time:.2f} seconds") logger.debug(f"System message updated with relevant context:\n {req.messages}") - print(f"Time to update system message {end_time - start_time:.2f} seconds") except Exception as e: logger.error("An error occurred within nilrag: %s", str(e)) From 502467e6823185d271042740093348edccca0b59 Mon Sep 17 00:00:00 2001 From: jfdreis Date: Fri, 28 Mar 2025 16:09:41 +0000 Subject: [PATCH 3/4] feat: adding measurements results in the response --- nilai-api/src/nilai_api/handlers/nilrag.py | 103 +++++++++--------- nilai-api/src/nilai_api/routers/private.py | 7 +- .../src/nilai_common/api_model.py | 16 +++ 3 files changed, 71 insertions(+), 55 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index fabededf..1e2c1395 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -2,7 +2,6 @@ import numpy as np import time import sys -import os import nilql import nilrag @@ -23,9 +22,6 @@ "sentence-transformers/all-MiniLM-L6-v2", device="cpu" ) # FIXME: Use a GPU model and move to a separate container -#Retrieve the ENABLE_MEASUREMENTS flag from environment variable -ENABLE_MEASUREMENTS = os.getenv("ENABLE_MEASUREMENTS", "0") in ["1", "True"] - def get_size_in_MB(obj): return sys.getsizeof(obj) / (1024 * 1024) @@ -85,9 +81,8 @@ def handle_nilrag(req: ChatRequest): num_parties = len(nilDB.nodes) additive_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"sum": True}) xor_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"store": True}) - if ENABLE_MEASUREMENTS: - end_time = time.time() - 
print(f"Initialization of secret keys took {end_time - start_time:.2f} seconds") + end_time = time.time() + secret_keys_initialization_time = end_time - start_time # Step 2: Secret share query logger.debug("Secret sharing query and sending to NilDB...") @@ -101,31 +96,24 @@ def handle_nilrag(req: ChatRequest): if query is None: raise HTTPException(status_code=400, detail="No user query found") - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to extract user query {end_time - start_time:.2f} seconds") + end_time = time.time() + extract_user_query_time = end_time - start_time # 2.2 Generate query embeddings: one string query is assumed. start_time = time.time() query_embedding = generate_embeddings_huggingface([query])[0] nilql_query_embedding = encrypt_float_list(additive_key, query_embedding) - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to generate query embedding {end_time - start_time:.2f} seconds") - query_size = get_size_in_MB(nilql_query_embedding) - print(f"Size of secret-shared query sent to NilDB: {query_size:.3f} MB") - + end_time = time.time() + embedding_generation_time = end_time - start_time + query_size = get_size_in_MB(nilql_query_embedding) # Step 3: Ask NilDB to compute the differences logger.debug("Requesting computation from NilDB...") start_time = time.time() difference_shares = nilDB.diff_query_execute(nilql_query_embedding) - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to ask nilDB to compute the differences {end_time - start_time:.2f} seconds") - diff_shares_size = get_size_in_MB(difference_shares) - print(f"Size of difference shares received: {diff_shares_size:.3f} MB") - + end_time = time.time() + asking_nilDB_time = end_time - start_time + difference_shares_size = get_size_in_MB(difference_shares) # Step 4: Compute distances and sort logger.debug("Compute distances and sort...") @@ -135,18 +123,16 @@ def handle_nilrag(req: ChatRequest): difference_shares, # type: ignore lambda share: share["difference"], ) - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to Group difference shares by ID {end_time - start_time:.2f} seconds") + end_time = time.time() + group_shares_by_id_time = end_time - start_time # 4.2 Transpose the lists for each _id start_time = time.time() difference_shares_by_id = { id: np.array(differences).T.tolist() for id, differences in difference_shares_by_id.items() } - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to Transpose the lists for each _id {end_time - start_time:.2f} seconds") + end_time = time.time() + transpose_lists_time = end_time - start_time # 4.3 Decrypt and compute distances start_time = time.time() reconstructed = [ @@ -158,15 +144,15 @@ def handle_nilrag(req: ChatRequest): } for id, difference_shares in difference_shares_by_id.items() ] - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to Decrypt and compute distances {end_time - start_time:.2f} seconds") + end_time = time.time() + decryption_time = end_time - start_time + # 4.4 Sort id list based on the corresponding distances start_time = time.time() sorted_ids = sorted(reconstructed, key=lambda x: x["distances"]) - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to Sort id list based on the corresponding distances {end_time - start_time:.2f} seconds") + end_time = time.time() + sort_id_list_time = end_time - start_time + # Step 5: Query the top k logger.debug("Query top k chunks...") top_k = 2 @@ -175,21 +161,17 @@ def handle_nilrag(req: ChatRequest): # 
5.1 Query top k start_time = time.time() chunk_shares = nilDB.chunk_query_execute(top_k_ids) - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to Query {top_k} chunks {end_time - start_time:.2f} seconds") - chunk_shares_size = get_size_in_MB(chunk_shares) - print(f"Size of chunk shares received: {chunk_shares_size:.3f} MB") - + end_time = time.time() + query_top_chunks_time = end_time - start_time + chunks_shares_size = get_size_in_MB(chunk_shares) # 5.2 Group chunk shares by ID start_time = time.time() chunk_shares_by_id = group_shares_by_id( chunk_shares, # type: ignore lambda share: share["chunk"], ) - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to Group chunk shares by ID {end_time - start_time:.2f} seconds") + end_time = time.time() + group_chunks_time = end_time - start_time # 5.3 Decrypt chunks start_time = time.time() @@ -197,9 +179,8 @@ def handle_nilrag(req: ChatRequest): {"_id": id, "distances": nilql.decrypt(xor_key, chunk_shares)} for id, chunk_shares in chunk_shares_by_id.items() ] - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to decrypt chunk {end_time - start_time:.2f} seconds") + end_time = time.time() + decrypt_chunks_time = end_time - start_time # Step 6: Format top results start_time = time.time() @@ -207,9 +188,8 @@ def handle_nilrag(req: ChatRequest): f"- {str(result['distances'])}" for result in top_results ) relevant_context = f"\n\nRelevant Context:\n{formatted_results}" - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to format top resuls {end_time - start_time:.2f} seconds") + end_time = time.time() + format_results_time = end_time - start_time # Step 7: Update system message start_time = time.time() @@ -227,13 +207,30 @@ def handle_nilrag(req: ChatRequest): else: # If no system message exists, add one req.messages.insert(0, Message(role="system", content=relevant_context)) - if ENABLE_MEASUREMENTS: - end_time = time.time() - print(f"Time to update system message {end_time - start_time:.2f} seconds") + end_time = time.time() + update_system_message_time = end_time - start_time logger.debug(f"System message updated with relevant context:\n {req.messages}") + return { + "secret_keys_initialization_time": secret_keys_initialization_time, + "extract_user_query_time": extract_user_query_time, + "embedding_generation_time": embedding_generation_time, + "query_size": query_size, + "asking_nilDB_time": asking_nilDB_time, + "group_shares_by_id_time": group_shares_by_id_time, + "transpose_lists_time": transpose_lists_time, + "decryption_time": decryption_time, + "sort_id_list_time": sort_id_list_time, + "query_top_chunks_time": query_top_chunks_time, + "group_chunks_time": group_chunks_time, + "decrypt_chunks_time": decrypt_chunks_time, + "format_results_time": format_results_time, + "update_system_message_time": update_system_message_time, + "difference_shares_size": difference_shares_size, + "chunks_shares_size": chunks_shares_size, + } except Exception as e: logger.error("An error occurred within nilrag: %s", str(e)) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) + ) \ No newline at end of file diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 5c2c9cd8..0bf703a0 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -195,8 +195,10 @@ async def chat_completion( f"Chat completion request for model {model_name} from user {user.userid} on url: 
{model_url}" ) + nilrag_metrics = {} if req.nilrag: - handle_nilrag(req) + nilrag_metrics = handle_nilrag(req) + logger.info(f"NilRag metrics: {nilrag_metrics}") if req.stream: client = AsyncOpenAI(base_url=model_url, api_key="") @@ -261,6 +263,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: model_response = SignedChatCompletion( **response.model_dump(), signature="", + **nilrag_metrics, ) if model_response.usage is None: raise HTTPException( @@ -286,4 +289,4 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: signature = sign_message(state.private_key, response_json) model_response.signature = b64encode(signature).decode() - return model_response + return model_response \ No newline at end of file diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 8576d931..029f6a5a 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -42,6 +42,22 @@ class ChatRequest(BaseModel): class SignedChatCompletion(ChatCompletion): signature: str + secret_keys_initialization_time: Optional[float] = None + extract_user_query_time: Optional[float] = None + embedding_generation_time: Optional[float] = None + asking_nilDB_time: Optional[float] = None + group_shares_by_id_time: Optional[float] = None + transpose_lists_time: Optional[float] = None + decryption_time: Optional[float] = None + sort_id_list_time: Optional[float] = None + query_top_chunks_time: Optional[float] = None + group_chunks_time: Optional[float] = None + decrypt_chunks_time: Optional[float] = None + format_results_time: Optional[float] = None + update_system_message_time: Optional[float] = None + query_size: Optional[float] = None + chunks_shares_size: Optional[float] = None + difference_shares_size: Optional[float] = None class AttestationResponse(BaseModel): From 88d5b6b5ae19fe39732a12fe8e7aa5d2bdc82dba Mon Sep 17 00:00:00 2001 From: jfdreis Date: Tue, 1 Apr 2025 12:48:22 +0100 Subject: [PATCH 4/4] feat: changed metrics presentation in the response --- nilai-api/src/nilai_api/handlers/nilrag.py | 67 ++++++++++--------- nilai-api/src/nilai_api/routers/private.py | 4 +- .../src/nilai_common/api_model.py | 19 +----- 3 files changed, 39 insertions(+), 51 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index 1e2c1395..ab5f5bcd 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -25,6 +25,9 @@ def get_size_in_MB(obj): return sys.getsizeof(obj) / (1024 * 1024) +def get_size_in_KB(obj): + return sys.getsizeof(obj) / 1024 + def generate_embeddings_huggingface( chunks_or_query: Union[str, list], ): @@ -82,7 +85,7 @@ def handle_nilrag(req: ChatRequest): additive_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"sum": True}) xor_key = nilql.secret_key({"nodes": [{}] * num_parties}, {"store": True}) end_time = time.time() - secret_keys_initialization_time = end_time - start_time + secret_keys_initialization_time = round(end_time - start_time, 2) # Step 2: Secret share query logger.debug("Secret sharing query and sending to NilDB...") @@ -97,23 +100,23 @@ def handle_nilrag(req: ChatRequest): if query is None: raise HTTPException(status_code=400, detail="No user query found") end_time = time.time() - extract_user_query_time = end_time - start_time + extract_user_query_time = round(end_time - start_time, 2) # 2.2 Generate query 
embeddings: one string query is assumed. start_time = time.time() query_embedding = generate_embeddings_huggingface([query])[0] nilql_query_embedding = encrypt_float_list(additive_key, query_embedding) end_time = time.time() - embedding_generation_time = end_time - start_time - query_size = get_size_in_MB(nilql_query_embedding) + embedding_generation_time = round(end_time - start_time, 2) + query_size = round(get_size_in_KB(nilql_query_embedding),2) # Step 3: Ask NilDB to compute the differences logger.debug("Requesting computation from NilDB...") start_time = time.time() difference_shares = nilDB.diff_query_execute(nilql_query_embedding) end_time = time.time() - asking_nilDB_time = end_time - start_time - difference_shares_size = get_size_in_MB(difference_shares) + asking_nilDB_time = round(end_time - start_time, 2) + difference_shares_size = round(get_size_in_KB(difference_shares),2) # Step 4: Compute distances and sort logger.debug("Compute distances and sort...") @@ -124,7 +127,7 @@ def handle_nilrag(req: ChatRequest): lambda share: share["difference"], ) end_time = time.time() - group_shares_by_id_time = end_time - start_time + group_shares_by_id_time = round(end_time - start_time, 2) # 4.2 Transpose the lists for each _id start_time = time.time() difference_shares_by_id = { @@ -132,7 +135,7 @@ def handle_nilrag(req: ChatRequest): for id, differences in difference_shares_by_id.items() } end_time = time.time() - transpose_lists_time = end_time - start_time + transpose_lists_time = round(end_time - start_time, 2) # 4.3 Decrypt and compute distances start_time = time.time() reconstructed = [ @@ -145,13 +148,13 @@ def handle_nilrag(req: ChatRequest): for id, difference_shares in difference_shares_by_id.items() ] end_time = time.time() - decryption_time = end_time - start_time + decryption_time = round(end_time - start_time, 2) # 4.4 Sort id list based on the corresponding distances start_time = time.time() sorted_ids = sorted(reconstructed, key=lambda x: x["distances"]) end_time = time.time() - sort_id_list_time = end_time - start_time + sort_id_list_time = round(end_time - start_time, 2) # Step 5: Query the top k logger.debug("Query top k chunks...") @@ -162,8 +165,8 @@ def handle_nilrag(req: ChatRequest): start_time = time.time() chunk_shares = nilDB.chunk_query_execute(top_k_ids) end_time = time.time() - query_top_chunks_time = end_time - start_time - chunks_shares_size = get_size_in_MB(chunk_shares) + query_top_chunks_time = round(end_time - start_time, 2) + chunks_shares_size = round(get_size_in_KB(chunk_shares), 2) # 5.2 Group chunk shares by ID start_time = time.time() chunk_shares_by_id = group_shares_by_id( @@ -171,7 +174,7 @@ def handle_nilrag(req: ChatRequest): lambda share: share["chunk"], ) end_time = time.time() - group_chunks_time = end_time - start_time + group_chunks_time = round(end_time - start_time, 2) # 5.3 Decrypt chunks start_time = time.time() @@ -180,7 +183,7 @@ def handle_nilrag(req: ChatRequest): for id, chunk_shares in chunk_shares_by_id.items() ] end_time = time.time() - decrypt_chunks_time = end_time - start_time + decrypt_chunks_time = round(end_time - start_time, 2) # Step 6: Format top results start_time = time.time() @@ -189,7 +192,7 @@ def handle_nilrag(req: ChatRequest): ) relevant_context = f"\n\nRelevant Context:\n{formatted_results}" end_time = time.time() - format_results_time = end_time - start_time + format_results_time = round(end_time - start_time, 2) # Step 7: Update system message start_time = time.time() @@ -208,25 +211,25 @@ def 
handle_nilrag(req: ChatRequest): # If no system message exists, add one req.messages.insert(0, Message(role="system", content=relevant_context)) end_time = time.time() - update_system_message_time = end_time - start_time + update_system_message_time = round(end_time - start_time, 2) logger.debug(f"System message updated with relevant context:\n {req.messages}") return { - "secret_keys_initialization_time": secret_keys_initialization_time, - "extract_user_query_time": extract_user_query_time, - "embedding_generation_time": embedding_generation_time, - "query_size": query_size, - "asking_nilDB_time": asking_nilDB_time, - "group_shares_by_id_time": group_shares_by_id_time, - "transpose_lists_time": transpose_lists_time, - "decryption_time": decryption_time, - "sort_id_list_time": sort_id_list_time, - "query_top_chunks_time": query_top_chunks_time, - "group_chunks_time": group_chunks_time, - "decrypt_chunks_time": decrypt_chunks_time, - "format_results_time": format_results_time, - "update_system_message_time": update_system_message_time, - "difference_shares_size": difference_shares_size, - "chunks_shares_size": chunks_shares_size, + "secret_keys_initialization_seconds": secret_keys_initialization_time, + "extract_user_query_seconds": extract_user_query_time, + "embedding_generation_seconds": embedding_generation_time, + "asking_nilDB_seconds": asking_nilDB_time, + "group_shares_by_id_seconds": group_shares_by_id_time, + "transpose_lists_seconds": transpose_lists_time, + "decryption_seconds": decryption_time, + "sort_id_list_seconds": sort_id_list_time, + "query_top_chunks_seconds": query_top_chunks_time, + "group_chunks_seconds": group_chunks_time, + "decrypt_chunks_seconds": decrypt_chunks_time, + "format_results_seconds": format_results_time, + "update_system_message_seconds": update_system_message_time, + "query_size_kbs": query_size, + "difference_shares_size_kbs": difference_shares_size, + "chunks_shares_size_kbs": chunks_shares_size, } except Exception as e: diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 0bf703a0..39c24946 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -198,7 +198,7 @@ async def chat_completion( nilrag_metrics = {} if req.nilrag: nilrag_metrics = handle_nilrag(req) - logger.info(f"NilRag metrics: {nilrag_metrics}") + logger.info(f"nilRag metrics: {nilrag_metrics}") if req.stream: client = AsyncOpenAI(base_url=model_url, api_key="") @@ -263,7 +263,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: model_response = SignedChatCompletion( **response.model_dump(), signature="", - **nilrag_metrics, + metrics=nilrag_metrics, ) if model_response.usage is None: raise HTTPException( diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 029f6a5a..c34839a5 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -42,22 +42,7 @@ class ChatRequest(BaseModel): class SignedChatCompletion(ChatCompletion): signature: str - secret_keys_initialization_time: Optional[float] = None - extract_user_query_time: Optional[float] = None - embedding_generation_time: Optional[float] = None - asking_nilDB_time: Optional[float] = None - group_shares_by_id_time: Optional[float] = None - transpose_lists_time: Optional[float] = None - decryption_time: Optional[float] = None - sort_id_list_time: Optional[float] = None - 
query_top_chunks_time: Optional[float] = None
-    group_chunks_time: Optional[float] = None
-    decrypt_chunks_time: Optional[float] = None
-    format_results_time: Optional[float] = None
-    update_system_message_time: Optional[float] = None
-    query_size: Optional[float] = None
-    chunks_shares_size: Optional[float] = None
-    difference_shares_size: Optional[float] = None
+    metrics: Optional[dict] = {}
 
 
 class AttestationResponse(BaseModel):
@@ -85,4 +70,4 @@ class ModelEndpoint(BaseModel):
 
 class HealthCheckResponse(BaseModel):
     status: str
-    uptime: str
+    uptime: str
\ No newline at end of file
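
After PATCH 4, `SignedChatCompletion` carries the nilRAG measurements in a single optional `metrics` dict instead of one typed field per step. Below is a minimal client-side sketch, not part of the patches above, showing how those measurements might be read back out of a response. It assumes only that the response body has already been obtained as JSON; the endpoint path, authentication, and the shape of the `nilrag` request payload are not specified in this series, and the numeric values in the example are hypothetical. The key names mirror those returned by `handle_nilrag` in PATCH 4.

# Illustrative sketch only: pretty-print the nilRAG metrics attached to a
# SignedChatCompletion response. Fetching the response (endpoint, auth,
# request payload) is out of scope here.
from typing import Any, Dict


def summarize_nilrag_metrics(response_json: Dict[str, Any]) -> str:
    """Format the optional `metrics` dict added in PATCH 4 for logging."""
    metrics = response_json.get("metrics") or {}
    if not metrics:
        return "no nilRAG metrics attached"

    lines = []
    for key, value in sorted(metrics.items()):
        # Keys ending in `_seconds` are wall-clock timings; keys ending in
        # `_kbs` are payload sizes in kilobytes, both rounded to two decimals
        # by handle_nilrag.
        unit = "s" if key.endswith("_seconds") else "KB" if key.endswith("_kbs") else ""
        lines.append(f"{key}: {value} {unit}".rstrip())
    return "\n".join(lines)


if __name__ == "__main__":
    # Example response fragment shaped like PATCH 4's output (hypothetical values).
    example = {
        "signature": "<signature>",
        "metrics": {
            "embedding_generation_seconds": 0.42,
            "asking_nilDB_seconds": 1.37,
            "query_size_kbs": 118.25,
        },
    }
    print(summarize_nilrag_metrics(example))

Packing the measurements into one dict keeps `SignedChatCompletion` stable if the set of measured steps changes later, at the cost of losing the per-field typing that PATCH 3 had introduced.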