Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend): Library v2 Agents and Presets #9258

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
3d21a12
fix `library.db.update_agent_version_in_library(..)`
Pwuts Jan 13, 2025
d07be7f
fix other DB calls in `v2.library.db`
Pwuts Jan 13, 2025
577357d
feat(backend): Library v2 Agents and Presets
Pwuts Jan 13, 2025
700f2a3
Merge branch 'pwuts/open-2309-fix-db-error-could-not-find-field-at' i…
Pwuts Jan 13, 2025
85180e1
fix `update_agent_version_in_library`
Pwuts Jan 13, 2025
a8ac410
refactor: Fix type checking on Prisma statements in `v2.library.db`
Pwuts Jan 13, 2025
46473f5
fix creation and test cleanup of (library) agents
Pwuts Jan 16, 2025
bdbce58
Merge branch 'dev' into pwuts/open-2314-fix-up-and-re-introduce-libra…
Pwuts Feb 10, 2025
3bde649
Merge branch 'dev' into pwuts/open-2314-fix-up-and-re-introduce-libra…
Pwuts Feb 11, 2025
a1b8e60
restore template versions endpoint
Pwuts Feb 11, 2025
3a23d89
remove template stuff from API
Pwuts Feb 12, 2025
d0ced8b
Merge branch 'dev' into pwuts/open-2314-fix-up-and-re-introduce-libra…
Pwuts Feb 12, 2025
5c62bef
address feedback
Pwuts Feb 12, 2025
9f225e6
fix linting issue
Pwuts Feb 12, 2025
9f86483
fix weird stuff
Pwuts Feb 12, 2025
158e32d
shorten imports; remove unused imports
Pwuts Feb 12, 2025
1c9c927
fix `node_input` endpoint parameters
Pwuts Feb 12, 2025
eefa6bc
imports dx in `.library.model`
Pwuts Feb 12, 2025
97db355
fix execution without node input
Pwuts Feb 13, 2025
680ad03
Merge branch 'dev' into pwuts/open-2314-fix-up-and-re-introduce-libra…
Pwuts Feb 13, 2025
7b7a247
fix use of Json fields
Pwuts Feb 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autogpt_platform/backend/backend/data/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ async def create_graph_execution(
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> tuple[str, list[ExecutionResult]]:
"""
Create a new AgentGraphExecution record.
Expand Down Expand Up @@ -168,6 +169,7 @@ async def create_graph_execution(
]
},
"userId": user_id,
"agentPresetId": preset_id,
},
include=GRAPH_EXECUTION_INCLUDE,
)
Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False,
template: bool = False, # note: currently not in use; TODO: remove from DB entirely
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
Expand Down
8 changes: 5 additions & 3 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def add_execution(
data: BlockInput,
user_id: str,
graph_version: Optional[int] = None,
preset_id: str | None = None,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
Expand All @@ -824,9 +825,9 @@ def add_execution(

# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
name = node.input_default.get("name")
if name in data.get("node_input", {}):
input_data = {"value": data["node_input"][name]}
input_name = node.input_default.get("name")
if input_name and input_name in data:
input_data = {"value": data[input_name]}

# Extract webhook payload, and assign it to the input pin
webhook_payload_key = f"webhook_{node.webhook_id}_payload"
Expand All @@ -851,6 +852,7 @@ def add_execution(
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)

starting_node_execs = []
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence
from typing import Annotated, Any, Dict, List, Optional, Sequence

from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException
from prisma.enums import AgentExecutionStatus, APIKeyPermission
from typing_extensions import TypedDict

Expand Down Expand Up @@ -101,7 +101,7 @@ def execute_graph_block(
def execute_graph(
graph_id: str,
graph_version: int,
node_input: dict[Any, Any],
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
api_key: APIKey = Depends(require_permission(APIKeyPermission.EXECUTE_GRAPH)),
) -> dict[str, Any]:
try:
Expand All @@ -113,7 +113,7 @@ def execute_graph(
)
return {"id": graph_exec.graph_exec_id}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)


Expand Down
19 changes: 16 additions & 3 deletions autogpt_platform/backend/backend/server/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ class ExecuteGraphResponse(pydantic.BaseModel):


class CreateGraph(pydantic.BaseModel):
template_id: str | None = None
template_version: int | None = None
graph: backend.data.graph.Graph | None = None
graph: backend.data.graph.Graph


class CreateAPIKeyRequest(pydantic.BaseModel):
Expand All @@ -57,5 +55,20 @@ class UpdatePermissionsRequest(pydantic.BaseModel):
permissions: List[APIKeyPermission]


class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[2]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)


class RequestTopUp(pydantic.BaseModel):
credit_amount: int
62 changes: 60 additions & 2 deletions autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.store.model
import backend.server.v2.store.routes
Expand Down Expand Up @@ -123,15 +125,15 @@ def run(self):
@staticmethod
async def test_execute_graph(
graph_id: str,
node_input: dict[str, Any],
user_id: str,
graph_version: Optional[int] = None,
node_input: Optional[dict[str, Any]] = None,
):
return backend.server.routers.v1.execute_graph(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
node_input=node_input,
node_input=node_input or {},
)

@staticmethod
Expand Down Expand Up @@ -170,8 +172,64 @@ async def test_get_graph_run_node_execution_results(

@staticmethod
async def test_delete_graph(graph_id: str, user_id: str):
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
graph_id=graph_id, user_id=user_id
)
return await backend.server.routers.v1.delete_graph(graph_id, user_id)

@staticmethod
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
return await backend.server.v2.library.routes.presets.get_presets(
user_id=user_id, page=page, page_size=page_size
)

@staticmethod
async def test_get_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.get_preset(
preset_id=preset_id, user_id=user_id
)

@staticmethod
async def test_create_preset(
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.create_preset(
preset=preset, user_id=user_id
)

@staticmethod
async def test_update_preset(
preset_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.update_preset(
preset_id=preset_id, preset=preset, user_id=user_id
)

@staticmethod
async def test_delete_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.delete_preset(
preset_id=preset_id, user_id=user_id
)

@staticmethod
async def test_execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
user_id: str,
node_input: Optional[dict[str, Any]] = None,
):
return await backend.server.v2.library.routes.presets.execute_preset(
graph_id=graph_id,
graph_version=graph_version,
preset_id=preset_id,
node_input=node_input or {},
user_id=user_id,
)

@staticmethod
async def test_create_store_listing(
request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str
Expand Down
109 changes: 23 additions & 86 deletions autogpt_platform/backend/backend/server/routers/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from typing_extensions import Optional, TypedDict

import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db as library_db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import (
Expand Down Expand Up @@ -310,11 +311,6 @@ async def get_graph(
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(
path="/templates/{graph_id}/versions",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
Expand All @@ -330,41 +326,18 @@ async def get_graph_all_versions(
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)


async def do_create_graph(
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
) -> graph_db.GraphModel:
if create_graph.graph:
graph = graph_db.make_graph_model(create_graph.graph, user_id)
elif create_graph.template_id:
# Create a new graph from a template
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
raise HTTPException(
400, detail=f"Template #{create_graph.template_id} not found"
)
graph.version = 1
else:
raise HTTPException(
status_code=400, detail="Either graph or template_id must be provided."
)

graph.is_template = is_template
graph.is_active = not is_template
graph = graph_db.make_graph_model(create_graph.graph, user_id)
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)

graph = await graph_db.create_graph(graph, user_id=user_id)

# Create a library agent for the new graph
await library_db.create_library_agent(
graph.id,
graph.version,
user_id,
)

graph = await on_graph_activate(
graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
Expand All @@ -391,11 +364,6 @@ def get_credentials(credentials_id: str) -> "Credentials | None":
@v1_router.put(
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
@v1_router.put(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def update_graph(
graph_id: str,
graph: graph_db.Graph,
Expand Down Expand Up @@ -427,6 +395,10 @@ async def update_graph(
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)

if new_graph_version.is_active:
# Keep the library agent up to date with the new active version
await library_db.update_agent_version_in_library(
user_id, graph.id, graph.version
)

def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
Expand Down Expand Up @@ -483,6 +455,12 @@ def get_credentials(credentials_id: str) -> "Credentials | None":
version=new_active_version,
user_id=user_id,
)

# Keep the library agent up to date with the new active version
await library_db.update_agent_version_in_library(
user_id, new_active_graph.id, new_active_graph.version
)

if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(
Expand All @@ -498,7 +476,7 @@ def get_credentials(credentials_id: str) -> "Credentials | None":
)
def execute_graph(
graph_id: str,
node_input: dict[Any, Any],
node_input: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
user_id: Annotated[str, Depends(get_user_id)],
graph_version: Optional[int] = None,
) -> ExecuteGraphResponse:
Expand All @@ -508,7 +486,7 @@ def execute_graph(
)
return ExecuteGraphResponse(graph_exec_id=graph_exec.graph_exec_id)
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
msg = str(e).encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)


Expand Down Expand Up @@ -559,47 +537,6 @@ async def get_graph_run_node_execution_results(
return await execution_db.get_execution_results(graph_exec_id)


########################################################
##################### Templates ########################
########################################################


@v1_router.get(
path="/templates",
tags=["graphs", "templates"],
dependencies=[Depends(auth_middleware)],
)
async def get_templates(
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="template", user_id=user_id)


@v1_router.get(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
return graph


@v1_router.post(
path="/templates",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)


########################################################
##################### Schedules ########################
########################################################
Expand Down
Loading
Loading