# Source code for pytest_experiments.store

import datetime as dt
from typing import List
from sqlalchemy import (
    create_engine,
    select,
    Column,
    Integer,
    Text,
    JSON,
    DateTime,
)
from sqlalchemy.orm import declarative_base, Session
from .config import EXPERIMENT_TABLENAME
from .common import mark_utc
from .json_tools import json_serializer, json_deserializer


Base = declarative_base()


class ExperimentModel(Base):
    """ORM mapping for one recorded experiment run.

    A row holds the experiment's name and outcome, its start and end
    timestamps, and two JSON documents: the input ``parameters`` and the
    ``data`` collected while the experiment ran.

    The ``JSON`` column type used here is currently only supported by the
    following backends:

    - PostgreSQL
    - MySQL
    - SQLite as of version 3.9
    - Microsoft SQL Server 2016 and later

    See https://docs.sqlalchemy.org/en/14/core/type_basics.html?highlight=data%20types#sqlalchemy.types.JSON  # noqa
    """

    __tablename__ = EXPERIMENT_TABLENAME

    # Auto-incrementing surrogate primary key.
    id = Column(
        "id",
        Integer,
        autoincrement=True,
        primary_key=True,
        comment="A unique identifier for an experiment run",
    )

    # Both timestamps use a plain ``DateTime`` (no ``timezone=True``) and are
    # not marked ``nullable=False``; the column comments describe them as UTC.
    start_time = Column(
        "start_time",
        DateTime,
        comment="The UTC timestamp of the experiment start",
    )
    end_time = Column(
        "end_time",
        DateTime,
        comment="The UTC timestamp of the experiment end",
    )

    # Required free-text descriptors.
    name = Column(
        "name",
        Text,
        comment="The name of the experiment",
        nullable=False,
    )
    outcome = Column(
        "outcome",
        Text,
        comment="The outcome of the experiment",
        nullable=False,
    )

    # Required JSON payloads.
    parameters = Column(
        "parameters",
        JSON,
        comment="The experiment input parameters",
        nullable=False,
    )
    data = Column(
        "data",
        JSON,
        comment="Data collected during the experiment",
        nullable=False,
    )

    @property
    def start_time_tz(self) -> dt.datetime:
        """Return ``start_time`` after passing it through ``mark_utc``.

        NOTE(review): ``start_time`` is nullable; how ``mark_utc`` treats
        ``None`` is not visible from this module - confirm.
        """
        stored_start = self.start_time
        return mark_utc(stored_start)

    @property
    def end_time_tz(self) -> dt.datetime:
        """Return ``end_time`` after passing it through ``mark_utc``.

        NOTE(review): ``end_time`` is nullable; how ``mark_utc`` treats
        ``None`` is not visible from this module - confirm.
        """
        stored_end = self.end_time
        return mark_utc(stored_end)
def initialize_database(engine):
    """Create the tables registered on ``Base`` in the target database.

    Args:
        engine: The SQLAlchemy engine handed to ``MetaData.create_all``.

    Returns:
        The result of ``Base.metadata.create_all(engine)``, passed through
        unchanged.
    """
    metadata = Base.metadata
    return metadata.create_all(engine)
class StorageManager:
    """Owns a database engine and offers helpers to store and load experiments."""

    def __init__(self, db_uri: str) -> None:
        """Build the engine for ``db_uri`` and initialize the database.

        Args:
            db_uri: The database URL handed to ``create_engine``.
        """
        self._db_uri = db_uri
        # The project's JSON (de)serializers are installed on the engine so
        # they are used for the JSON columns; ``future=True`` selects the
        # SQLAlchemy 2.0-style engine API.
        engine_options = {
            "json_serializer": json_serializer,
            "json_deserializer": json_deserializer,
            "future": True,
        }
        self.engine = create_engine(db_uri, **engine_options)
        initialize_database(self.engine)

    @property
    def db_uri(self) -> str:
        """The database URL this manager was constructed with (read-only)."""
        return self._db_uri

    def create_session(self) -> Session:
        """Return a new ``Session`` bound to this manager's engine."""
        return Session(bind=self.engine)

    def record_experiment(self, experiment: ExperimentModel):
        """Persist one experiment inside its own transaction.

        Args:
            experiment (ExperimentModel): The experiment to record.

        NOTE(review): the session is closed on exit, so ``experiment`` is
        presumably left detached (and expired by the commit) afterwards -
        confirm whether callers read attributes such as ``id`` after this.
        """
        with self.create_session() as session:
            # ``session.begin()`` commits on a clean exit and rolls back if
            # the block raises.
            with session.begin():
                session.add(experiment)

    def get_all_experiments(self) -> List[ExperimentModel]:
        """Load and return every experiment row in the database."""
        query = select(ExperimentModel)
        with self.create_session() as session:
            result = session.execute(query)
            return result.scalars().all()