"""Schemas for classical MD package."""

from __future__ import annotations

import zlib
from dataclasses import dataclass
from datetime import datetime

from monty.json import MSONable
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    PlainSerializer,
    PlainValidator,
    WithJsonSchema,
)
from pymatgen.core import Structure
from typing_extensions import Annotated

from emmet.core.utils import utcnow
from emmet.core.vasp.task_valid import TaskState


def compressed_str_validator(s: str) -> str:
    """Decompress a hex-encoded, zlib-compressed UTF-8 string.

    Values produced by ``compressed_str_serializer`` are decoded back to the
    original text. Input that is not a valid hex / zlib / UTF-8 payload (for
    example an ordinary, uncompressed string) is returned unchanged. This lets
    the validator accept both compressed and uncompressed values.

    Args:
        s: Hex string of zlib-compressed UTF-8 bytes, or an uncompressed value.

    Returns:
        The decompressed text, or ``s`` unchanged if it could not be decoded.
    """
    try:
        compressed_bytes = bytes.fromhex(s)
        decompressed_bytes = zlib.decompress(compressed_bytes)
        return decompressed_bytes.decode("utf-8")
    except (TypeError, ValueError, zlib.error):
        # TypeError: input is not a str (bytes.fromhex only accepts str).
        # ValueError: not valid hex, or not valid UTF-8 (UnicodeDecodeError
        #   is a ValueError subclass).
        # zlib.error: valid hex that is not a zlib stream.
        # Anything else (KeyboardInterrupt, MemoryError, ...) must propagate.
        return s


def compressed_str_serializer(s: str) -> str:
    """Compress text with zlib and return the result as a hex string.

    This is the inverse of ``compressed_str_validator``. The text is encoded as
    UTF-8, compressed with zlib at its default level, and hex-encoded so the
    result is itself a plain ASCII string.
    """
    payload = zlib.compress(s.encode("utf-8"))
    return payload.hex()


# Annotated ``str`` type that is stored compressed.
# - On serialization, the value is zlib-compressed and hex-encoded
#   (``compressed_str_serializer``).
# - On validation, a compressed hex payload is decompressed back to text.
#   Input that cannot be decoded that way is passed through unchanged
#   (``compressed_str_validator``).
# - The generated JSON schema advertises a plain string in both cases.
CompressedStr = Annotated[
    str,
    PlainValidator(compressed_str_validator),
    PlainSerializer(compressed_str_serializer),
    WithJsonSchema({"type": "string"}),
]


@dataclass
class MoleculeSpec(MSONable):
    """A molecule schema to be output by OpenMMGenerators.

    A plain dataclass that is also ``MSONable``, so instances can be
    round-tripped through monty's JSON serialization. Each instance presumably
    describes one molecular species in a simulated system — confirm against
    the generators that build it.
    """

    name: str  # identifier for this molecule entry
    count: int  # presumably the number of copies of this molecule in the system — confirm
    charge_scaling: float  # NOTE(review): assumed multiplicative factor on partial charges; verify
    charge_method: str  # NOTE(review): assumed name of the partial-charge assignment method; verify
    openff_mol: str  # a tk.Molecule object serialized with to_json


class MDTaskDocument(BaseModel):  # type: ignore[call-arg]
    """Definition of the OpenMM task document.

    Base task document for a classical MD calculation. It collects task
    metadata, per-calculation outputs and, optionally, the final structure.
    ``model_config`` sets ``extra="allow"``, so fields not declared here are
    accepted and kept on the instance instead of being rejected.
    """

    tags: list[str] | None = Field(
        [], title="tag", description="Metadata tagged to a given task."
    )

    dir_name: str | None = Field(None, description="The directory for this MD task")

    state: TaskState | None = Field(None, description="State of this calculation")

    job_uuids: list | None = Field(
        None,
        # Adjacent string literals are joined verbatim, so the separating
        # space must be written explicitly at the end of the first literal.
        description="The job_uuids for all contributing jobs, this will only "
        "have a value if the taskdoc is generated by a Flow.",
    )

    calcs_reversed: list | None = Field(
        None,
        title="Calcs reversed data",
        description="Detailed data for each MD calculation contributing to "
        "the task document.",
    )

    # CompressedStr: zlib-compressed + hex-encoded on serialization and
    # transparently decompressed on validation.
    interchange: CompressedStr | None = Field(
        None, description="An interchange object serialized to json."
    )

    mol_specs: list[MoleculeSpec] | None = Field(
        None,
        description="Molecules within the system. Only makes sense "
        "for molecular systems.",
    )

    structure: Structure | None = Field(
        None,
        title="Structure",
        description="The final structure for the simulation. Saved only "
        "if specified by job.",
    )

    force_field: str | None = Field(None, description="The classical MD forcefield.")

    task_type: str | None = Field(None, description="The type of calculation.")

    # TODO: add ``task_label: str | None`` ("A description of the task") once it
    # is clear where in the workflow the label gets populated.

    # default_factory makes utcnow() run per instance, not once at import.
    last_updated: datetime = Field(
        default_factory=utcnow,
        description="Timestamp for the most recent calculation for this task document",
    )

    model_config = ConfigDict(extra="allow")


class ClassicalMDTaskDocument(MDTaskDocument):
    """Task document for a classical MD (OpenMM) calculation.

    Inherits every field from ``MDTaskDocument``. It only redeclares
    ``mol_specs``, with a shorter description.
    """

    mol_specs: list[MoleculeSpec] | None = Field(
        default=None,
        description="Molecules within the system.",
    )
