Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Required test #346

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datapipe/meta/sql_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def get_existing_idx(self, idx: Optional[IndexDF] = None) -> IndexDF:
# Empty index -> empty result
return cast(
IndexDF,
pd.DataFrame(columns=[column.name for column in self.sql_schema]), # type: ignore
pd.DataFrame(columns=self.primary_keys), # type: ignore
)
idx_cols = list(set(idx.columns.tolist()) & set(self.primary_keys))
else:
Expand Down
126 changes: 124 additions & 2 deletions tests/test_complex_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.sql.sqltypes import Integer, String

from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps
from datapipe.datatable import DataStore
from datapipe.datatable import DataStore, DataTable
from datapipe.step.batch_generate import BatchGenerate
from datapipe.step.batch_transform import BatchTransform
from datapipe.store.database import TableStoreDB
Expand Down Expand Up @@ -324,7 +324,6 @@ def complex_transform_with_many_recordings(dbconn, N: int):
[
Column("image_id", Integer, primary_key=True),
Column("model_id", Integer, primary_key=True),
Column("prediction__attribite", Integer),
],
True,
)
Expand Down Expand Up @@ -445,3 +444,126 @@ def test_complex_transform_with_many_recordings_N1000(dbconn):
@pytest.mark.skip(reason="fails on sqlite")
def test_complex_transform_with_many_recordings_N10000(dbconn):
    """Run the many-recordings scenario with N=10000 (currently skipped)."""
    complex_transform_with_many_recordings(dbconn, N=10_000)


def test_applying_prediction_on_best_model_only(dbconn):
    """Check that predictions are produced only for the current best model.

    The pipeline fills ``tbl_image`` (N images), ``tbl_model`` (models 0..4)
    and ``tbl_best_model`` (model 4), then runs a transform whose
    ``tbl_best_model`` input is wrapped in ``Required``.

    Phase 1 expects ``tbl_prediction`` to hold exactly one row per image for
    model 4.  Phase 2 replaces the best model with model 3, re-runs the steps
    and expects exactly one row per image for model 3.

    NOTE(review): the semantics of ``Required`` are not visible here;
    presumably it restricts the processed index to keys present in
    ``tbl_best_model`` -- confirm against its implementation.
    """
    N = 100
    ds = DataStore(dbconn, create_meta_table=True)

    def make_table(name, key_columns):
        # Every table in this test consists of integer primary-key columns
        # only, so they differ just by name and key list.
        return Table(
            store=TableStoreDB(
                dbconn,
                name,
                [Column(key, Integer, primary_key=True) for key in key_columns],
                True,
            )
        )

    catalog = Catalog(
        {
            "tbl_image": make_table("tbl_image", ["image_id"]),
            "tbl_model": make_table("tbl_model", ["model_id"]),
            "tbl_best_model": make_table("tbl_best_model", ["model_id"]),
            "tbl_prediction": make_table(
                "tbl_prediction", ["image_id", "model_id"]
            ),
        }
    )

    def gen_tbls(df1, df2, df3):
        yield df1, df2, df3

    test_df__image = pd.DataFrame({"image_id": range(N)})
    test_df__model = pd.DataFrame({"model_id": [0, 1, 2, 3, 4]})
    test_df__best_model = pd.DataFrame({"model_id": [4]})

    # Model ids the transform is allowed to see.  Kept mutable so the second
    # phase (best model switched to 3) can update the expectation; a
    # hard-coded ``== 4`` check would reject the very batches that have to
    # produce the model-3 predictions asserted at the end of the test.
    expected_best_model_ids = set(test_df__best_model["model_id"])

    def inference_only_on_best_model(
        df__image: pd.DataFrame,
        df__model: pd.DataFrame,
        df__best_model: pd.DataFrame,
        idx: IndexDF,
    ):
        # The transform must only ever be invoked for the current best model.
        assert set(idx["model_id"]) <= expected_best_model_ids
        df__prediction = pd.merge(df__image, df__model, how="cross")
        return df__prediction[["image_id", "model_id"]]

    pipeline = Pipeline(
        [
            BatchGenerate(
                func=gen_tbls,
                outputs=[
                    "tbl_image",
                    "tbl_model",
                    "tbl_best_model",
                ],
                kwargs=dict(
                    df1=test_df__image,
                    df2=test_df__model,
                    df3=test_df__best_model,
                ),
            ),
            BatchTransform(
                func=inference_only_on_best_model,
                inputs=[
                    "tbl_image",  # image_id
                    "tbl_model",  # model_id
                    Required("tbl_best_model"),  # model_id
                ],
                outputs=["tbl_prediction"],
                transform_keys=["image_id", "model_id"],
            ),
        ]
    )
    steps = build_compute(ds, catalog, pipeline)

    # Phase 1: best model is 4 -> one prediction per image for model 4 only.
    run_steps(ds, steps)
    test__df_prediction = pd.DataFrame(
        {"image_id": range(N), "model_id": [4] * N}
    )
    assert_df_equal(
        ds.get_table("tbl_prediction").get_data(),
        test__df_prediction,
        index_cols=["image_id", "model_id"],
    )

    # Phase 2: replace the best model with 3 directly in the table and re-run.
    test_df__new_best_model = pd.DataFrame({"model_id": [3]})
    expected_best_model_ids.clear()
    expected_best_model_ids.update(test_df__new_best_model["model_id"])

    dt__tbl_best_model: DataTable = ds.get_table("tbl_best_model")
    dt__tbl_best_model.delete_by_idx(dt__tbl_best_model.get_data())
    dt__tbl_best_model.store_chunk(test_df__new_best_model)
    run_steps(ds, steps)
    test__new_df_prediction = pd.DataFrame(
        {"image_id": range(N), "model_id": [3] * N}
    )
    assert_df_equal(
        ds.get_table("tbl_prediction").get_data(),
        test__new_df_prediction,
        index_cols=["image_id", "model_id"],
    )
Loading