Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feat: Add W&B weave tracing #14262

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
21 changes: 21 additions & 0 deletions api/core/ops/entities/config_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class TracingProviderEnum(Enum):
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
WEAVE = "weave"


class BaseTracingConfig(BaseModel):
Expand Down Expand Up @@ -87,6 +88,26 @@ def url_validator(cls, v, info: ValidationInfo):

return v

class WeaveConfig(BaseTracingConfig):
"""
Model class for Weave tracing config.
"""

api_key: str
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"

@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://trace.wandb.ai"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")

return v


OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
8 changes: 8 additions & 0 deletions api/core/ops/ops_trace_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LangfuseConfig,
LangSmithConfig,
OpikConfig,
WeaveConfig,
TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
Expand All @@ -33,6 +34,7 @@
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.weave_trace.weave_trace import WeaveDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
Expand Down Expand Up @@ -60,6 +62,12 @@
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
},
TracingProviderEnum.WEAVE.value: {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint"],
"trace_instance": WeaveDataTrace,
},
}


Expand Down
Empty file.
Empty file.
89 changes: 89 additions & 0 deletions api/core/ops/weave_trace/entities/weave_trace_entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from pydantic import BaseModel, Field, field_validator
from pydantic_core.core_schema import ValidationInfo
from typing import Any, Union, Optional, List, Dict

from core.ops.utils import replace_text_with_content

class WeaveTokenUsage(BaseModel):
input_tokens: Optional[int] = None
output_tokens: Optional[int] = None
total_tokens: Optional[int] = None

class WeaveMultiModel(BaseModel):
file_list: Optional[list[str]] = Field(None, description="List of files")



class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
id: str = Field(..., description="ID of the trace")
op: str = Field(..., description="Name of the operation")
inputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Outputs of the trace")
attributes: Optional[Union[str, Dict[str, Any], List, None]] = Field(None, description="Metadata and attributes associated with trace")
exception: Optional[str] = Field(None, description="Exception message of the trace")

@field_validator("inputs", "outputs")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
values = info.data
if v == {} or v is None:
return v
usage_metadata = {
"input_tokens": values.get("input_tokens", 0),
"output_tokens": values.get("output_tokens", 0),
"total_tokens": values.get("total_tokens", 0),
}
file_list = values.get("file_list", [])
if isinstance(v, str):
if field_name == "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
if field_name == "inputs":
data = {
"messages": v,
}
elif field_name == "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
return data
else:
return {
"choices": {
"role": "ai" if field_name == "outputs" else "user",
"content": str(v),
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
if isinstance(v, dict):
v["usage_metadata"] = usage_metadata
v["file_list"] = file_list
return v
return v
187 changes: 187 additions & 0 deletions api/core/ops/weave_trace/weave_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)

from core.ops.utils import filter_none_values, generate_dotted_order
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
import weave
import wandb
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel

logger = logging.getLogger(__name__)


class WeaveDataTrace(BaseTraceInstance):
def __init__(
self,
weave_config: WeaveConfig,
):
super().__init__(weave_config)
self.weave_api_key = weave_config.api_key
self.project_name = weave_config.project
self.entity = weave_config.entity
self.weave_client = weave.init(project_name=f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.calls = {}

def get_project_url(self,):
try:
project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
return project_url
except Exception as e:
logger.debug(f"Weave get run url failed: {str(e)}")
raise ValueError(f"Weave get run url failed: {str(e)}")


def trace(self, trace_info: BaseTraceInfo):
logger.debug(f"Trace info: {trace_info}")
print("Trace info: ", trace_info)
if isinstance(trace_info, WorkflowTraceInfo):
# self.workflow_trace(trace_info)
print("Workflow trace: ", trace_info)
pass
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
print("Moderation trace: ", trace_info)
pass
# self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
print("Suggested question trace: ", trace_info)
pass
# self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
print("Dataset retrieval trace: ", trace_info)
pass
# self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
print("Tool trace: ", trace_info)
pass
# self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
print("Generate name trace: ", trace_info)
pass
# self.generate_name_trace(trace_info)

def message_trace(self, trace_info: MessageTraceInfo):
# get message file data
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
metadata = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id

user_id = message_data.from_account_id
metadata["user_id"] = user_id

if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
metadata["end_user_id"] = end_user_id

metadata["message_id"] = message_id
metadata["start_time"]=trace_info.start_time
metadata["end_time"]=trace_info.end_time
metadata["tags"] = ["message", str(trace_info.conversation_mode)]
message_run = WeaveTraceModel(
id=message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
inputs=trace_info.inputs,
outputs=trace_info.outputs,
exception=trace_info.error,
file_list=file_list,
attributes=metadata
)
self.add_run(message_run)

# create llm run parented to message run
llm_run = WeaveTraceModel(
id=str(uuid.uuid4()),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
op="llm",
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=metadata,
)
self.add_run(llm_run, parent_run_id=message_id,)
self.update_run(llm_run)
self.update_run(message_run)

def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return

metadata = trace_info.metadata
metadata["tags"] = ["moderation"]
metadata["start_time"] = trace_info.start_time or trace_info.message_data.created_at,
metadata["end_time"] = trace_info.end_time or trace_info.message_data.updated_at,

moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.MODERATION_TRACE.value),
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
attributes=metadata,
)
self.add_run(moderation_run, parent_run_id=trace_info.message_id)

def api_check(self):
try:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status:
raise ValueError("Weave login failed")
else:
print("Weave login successful")
return True
except Exception as e:
logger.debug(f"Weave API check failed: {str(e)}")
raise ValueError(f"Weave API check failed: {str(e)}")

def add_run(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
self.calls[run_data.id] = call
if parent_run_id:
self.calls[run_data.id].parent_id = parent_run_id

def update_run(self, run_data: WeaveTraceModel):
call = self.calls.get(run_data.id)
if call:
self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
else:
raise ValueError(f"Call with id {run_data['id']} not found")
1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ tokenizers = "~0.15.0"
transformers = "~4.35.0"
unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
validators = "0.21.0"
weave = "~0.51.34"
yarl = "~1.18.3"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.

Expand Down
9 changes: 8 additions & 1 deletion api/services/ops_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,14 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str):
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})

if tracing_provider == "weave" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict()

Expand Down
Loading