"""Compile a Stan model extension module given code written in Stan.
These functions manage the process of compiling a Python extension module
from C++ code generated and loading the resulting module.
"""
import asyncio
import base64
import hashlib
import importlib
import importlib.resources
import importlib.util
import logging
import platform
import sys
from importlib.machinery import EXTENSION_SUFFIXES
from pathlib import Path
from types import ModuleType
from typing import List, Optional, Tuple

import setuptools

import httpstan.build_ext
import httpstan.cache
import httpstan.compile
# Absolute directory containing this module. `strict=True` makes `resolve`
# raise if the path does not exist. The build function in this module looks
# up the `include` and `lib` subdirectories and the `stan_services.o` object
# file relative to this directory.
PACKAGE_DIR = Path(__file__).parent.resolve(strict=True)
# Logger registered under the fixed name "httpstan" (not `__name__`).
logger = logging.getLogger("httpstan")
[docs]
def calculate_model_name(program_code: str) -> str:
    """Derive a model name from Stan program code.

    Names look like this: ``models/2uxewutp``. The part after the slash is
    the lowercase base32 text of a 5-byte BLAKE2b digest. The digest covers,
    in this order, the UTF-8 encodings of:

    - the Stan program code
    - the httpstan version (``httpstan.__version__``)
    - the system platform (``sys.platform``)
    - the system bit architecture, recorded as ``str(sys.maxsize)``
    - the Python version (``sys.version``)
    - the Python executable path (``sys.executable``)

    Arguments:
        program_code: Stan program code.

    Returns:
        str: model name
    """
    # The order of these parts is part of the naming scheme: reordering them
    # would change the name computed for every program.
    fingerprint_parts = (
        program_code,
        # system identifiers
        httpstan.__version__,
        sys.platform,
        str(sys.maxsize),
        sys.version,
        # sys.executable accounts for different `venv`s
        sys.executable,
    )
    # a 5-byte digest means we expect a collision after a million models
    hasher = hashlib.blake2b(digest_size=5)
    for part in fingerprint_parts:
        hasher.update(part.encode())
    suffix = base64.b32encode(hasher.digest()).decode().lower()
    return f"models/{suffix}"
[docs]
def import_services_extension_module(model_name: str) -> ModuleType:
    """Load an existing model-specific stan::services extension module.

    The model's cache directory is scanned and the first entry yielded whose
    file suffix is one of the interpreter's extension-module suffixes is
    loaded.

    Arguments:
        model_name: Name of the model whose extension module should be
            loaded (for example a value returned by
            ``calculate_model_name``).

    Returns:
        module: loaded module handle.

    Raises:
        KeyError: Model not found.
    """
    model_directory = httpstan.cache.model_directory(model_name)
    try:
        # A missing directory (FileNotFoundError) and a directory without any
        # extension module (StopIteration) both mean the model was never built.
        module_path = next(
            entry for entry in model_directory.iterdir() if entry.suffix in EXTENSION_SUFFIXES
        )
    except (FileNotFoundError, StopIteration) as exc:
        # Chain explicitly so the traceback shows the underlying cause.
        raise KeyError(f"No module for `{model_name}` found in `{model_directory}`") from exc
    # The module name, which is independent of the filename, is always "stan_services". The module
    # name must be defined in stan_services.cpp, which is compiled before we know with which
    # specific stan model it will be linked with. Since we want to compile stan_services.cpp in
    # advance, we are stuck with a fixed module name.
    spec = importlib.util.spec_from_file_location("stan_services", module_path)  # type: ignore
    module: ModuleType = importlib.util.module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)  # type: ignore
    return module
[docs]
async def build_services_extension_module(program_code: str, extra_compile_args: Optional[List[str]] = None) -> str:
    """Compile a model-specific stan::services extension module.

    Since compiling an extension module takes a long time, compilation takes
    place in a different thread.

    Messages generated by the compiler—normally sent to stderr—are collected
    and saved. These messages are returned by the function.

    This is a coroutine function.

    IMPORTANT NOTE: This function builds the extension module in the cache
    directory, making it available for later `import`ing. This "side-effect" is
    why there are no functions called `load_services_extension_module` and
    `dump_services_extension_module`.

    Arguments:
        program_code: Stan program code.
        extra_compile_args: Flags handed to the extension's compile step.
            When ``None``, ``-O3``, ``-std=c++14`` and ``-Wno-sign-compare``
            are used.

    Returns:
        str: compiler messages.
    """
    model_name = calculate_model_name(program_code)
    model_directory_path = httpstan.cache.model_directory(model_name)
    model_directory_path.mkdir(parents=True, exist_ok=True)

    # Model names look like "models/<id>"; the generated C++ model is "model_<id>".
    stan_model_name = f"model_{model_name.split('/')[1]}"
    cpp_code, _ = httpstan.compile.compile(program_code, stan_model_name)
    cpp_code_path = model_directory_path / f"{stan_model_name}.cpp"
    # Write with an explicit encoding: the locale default may be unable to
    # represent non-ASCII text in the generated source.
    cpp_code_path.write_text(cpp_code, encoding="utf-8")

    include_dirs = [
        str(model_directory_path),
        str(PACKAGE_DIR / "include"),
    ]

    stan_macros: List[Tuple[str, Optional[str]]] = [
        ("BOOST_DISABLE_ASSERTS", None),
        ("BOOST_PHOENIX_NO_VARIADIC_EXPRESSION", None),
        ("STAN_THREADS", None),
        ("_REENTRANT", None),  # required by stan math / std::lgamma
        # the following is needed on linux for compatibility with libraries built with the manylinux2014 image
        ("_GLIBCXX_USE_CXX11_ABI", "0"),
    ]

    if extra_compile_args is None:
        extra_compile_args = [
            "-O3",
            "-std=c++14",
            "-Wno-sign-compare",
        ]

    # Note: `library_dirs` is only relevant for linking. It does not tell an extension
    # where to find shared libraries during execution. There are two ways for an
    # extension module to find shared libraries: LD_LIBRARY_PATH and rpath.
    libraries = ["sundials_cvodes", "sundials_idas", "sundials_nvecserial", "tbb"]
    if platform.system() == "Darwin":  # pragma: no cover
        libraries.extend(["tbbmalloc", "tbbmalloc_proxy"])

    extension = setuptools.Extension(
        f"stan_services_{stan_model_name}",  # filename only. Module name is "stan_services"
        language="c++",
        sources=[str(cpp_code_path)],
        define_macros=stan_macros,
        include_dirs=include_dirs,
        library_dirs=[str(PACKAGE_DIR / "lib")],
        libraries=libraries,
        extra_compile_args=extra_compile_args,
        # rpath lets the built module locate the shared libraries at run time
        extra_link_args=[f"-Wl,-rpath,{PACKAGE_DIR / 'lib'}"],
        # object file for stan_services.cpp, expected next to this module
        extra_objects=[
            str((PACKAGE_DIR / "stan_services.cpp").with_suffix(".o")),
        ],
    )
    extensions = [extension]
    build_lib = str(model_directory_path)

    # Building the model takes a long time. Run in a different thread.
    compiler_output = await asyncio.get_running_loop().run_in_executor(
        None, httpstan.build_ext.run_build_ext, extensions, build_lib
    )
    return compiler_output