diff --git a/pyproject.toml b/pyproject.toml index 3a34c0c..a10c78b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,20 +5,24 @@ build-backend = "setuptools.build_meta" [project] name = "sensgw" version = "0.1.0" -description = "My Python project" +description = "Sensor gateway" readme = "README.md" requires-python = ">=3.8" license = {text = "MIT"} authors = [{name = "Your Name", email = "you@example.com"}] classifiers = [ - "Programming Language :: Python :: 3", - "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3", + "Operating System :: POSIX :: Linux", ] -dependencies = [ # Your runtime deps, e.g. - "requests>=2.25.0", - "paho-mqtt>=2.0.0" +dependencies = [ + "asyncpg>=0.31.0", + "requests>=2.25.0", + "paho-mqtt>=2.0.0", + "pysnmp>=7.1.22", + "pymeasure>=0.15.0", + "PyVISA-py>=0.8.1" ] -[project.scripts] # Entry points for CLI scripts (PEP 621) -my-script = "sensgw.main:main" +[project.scripts] +sensgw = "sensgw.main:main" diff --git a/sensgw.out b/sensgw.out new file mode 100644 index 0000000..7ef5e7e --- /dev/null +++ b/sensgw.out @@ -0,0 +1,6 @@ +nohup: ignoring input +INFO:sensgw:Running with 4 task(s) +WARNING:pymeasure.adapters.vxi11:Failed to import vxi11 package, which is required for the VXI11Adapter +/home/mira/sensgw/venv/lib/python3.13/site-packages/pymeasure/instruments/generic_types.py:110: FutureWarning: It is not known whether this device support SCPI commands or not. Please inform the pymeasure maintainers if you know the answer. + warn("It is not known whether this device support SCPI commands or not. Please inform " +INFO:pymeasure.instruments.instrument:Initializing Keithley 2000 Multimeter. diff --git a/sensgw/config.py b/sensgw/config.py new file mode 100644 index 0000000..66706c5 --- /dev/null +++ b/sensgw/config.py @@ -0,0 +1,24 @@ +# sensgw/config.py +from dataclasses import dataclass +import os + + +@dataclass(frozen=True) +class Config: + db_dsn: str + log_level: str = "INFO" + registry_refresh_s: int = 60 + default_poll_interval_s: int = 30 + + +def load_config() -> Config: + dsn = os.environ.get("SENSGW_DB_DSN") + if not dsn: + raise RuntimeError("Missing SENSGW_DB_DSN") + return Config( + db_dsn=dsn, + log_level=os.environ.get("SENSGW_LOG_LEVEL", "INFO"), + registry_refresh_s=int(os.environ.get("SENSGW_REGISTRY_REFRESH_S", "60")), + default_poll_interval_s=int(os.environ.get("SENSGW_DEFAULT_POLL_S", "30")), + ) + diff --git a/sensgw/db.py b/sensgw/db.py new file mode 100644 index 0000000..35292d5 --- /dev/null +++ b/sensgw/db.py @@ -0,0 +1,39 @@ +# sensgw/db.py +import asyncpg +import json +from typing import Optional + + +async def _init_connection(con: asyncpg.Connection) -> None: + await con.set_type_codec( + "json", + encoder=json.dumps, + decoder=json.loads, + schema="pg_catalog", + ) + await con.set_type_codec( + "jsonb", + encoder=json.dumps, + decoder=json.loads, + schema="pg_catalog", + ) + + +class Database: + def __init__(self, dsn: str): + self._dsn = dsn + self.pool: Optional[asyncpg.Pool] = None + + async def start(self) -> None: + self.pool = await asyncpg.create_pool( + dsn=self._dsn, + min_size=1, + max_size=10, + init=_init_connection, + ) + + async def stop(self) -> None: + if self.pool: + await self.pool.close() + self.pool = None + diff --git a/sensgw/main.py b/sensgw/main.py index e69de29..aa5981d 100644 --- a/sensgw/main.py +++ b/sensgw/main.py @@ -0,0 +1,130 @@ +# sensgw/main.py +from __future__ import annotations + +import asyncio +import logging + +from .config import load_config +from .db import Database +from .registry import load_registry +from .writer import Writer + +from .protocols.mqtt import MqttCollector, MqttBinding +from .protocols.prologix import PrologixEndpointCollector, PrologixBinding +from .protocols.snmp import SnmpEndpointCollector, SnmpBinding +# from .protocols.visa import VisaCollector, VisaBinding + + +async def _run() -> None: + cfg = load_config() + logging.basicConfig(level=getattr(logging, cfg.log_level.upper(), logging.INFO)) + log = logging.getLogger("sensgw") + + db = Database(cfg.db_dsn) + await db.start() + writer = Writer(db) + + try: + reg = await load_registry(db) + by_proto = reg.channels_by_protocol() + + tasks: list[asyncio.Task[None]] = [] + + # --- MQTT (assumes one broker; if multiple, split by endpoint_id) --- + mqtt_bindings: list[MqttBinding] = [] + for ep, dev, ch in by_proto.get("mqtt", []): + src = ch.source + if src.get("type") != "mqtt_topic": + continue + mqtt_bindings.append( + MqttBinding( + endpoint=ep, + device=dev, + channel=ch, + topic=str(src["topic"]), + field=str(src.get("field", "")), + payload=str(src.get("payload", "json")), + ) + ) + + if mqtt_bindings: + mqttc = MqttCollector(writer) + tasks.append(asyncio.create_task(mqttc.run(mqtt_bindings), name="mqtt")) + + # --- SNMP (one task per endpoint) --- + snmp_bindings_by_ep: dict[int, tuple[object, list[SnmpBinding]]] = {} + for ep, dev, ch in by_proto.get("snmp", []): + src = ch.source + if src.get("type") != "snmp_oid": + continue + + b = SnmpBinding( + endpoint=ep, + device=dev, + channel=ch, + oid=str(src["oid"]), + datatype=str(src.get("datatype", "float")), + ) + snmp_bindings_by_ep.setdefault(ep.endpoint_id, (ep, []))[1].append(b) + + snmpc = SnmpEndpointCollector(writer, default_poll_s=cfg.default_poll_interval_s) + for _ep_id, (ep, bindings) in snmp_bindings_by_ep.items(): + tasks.append( + asyncio.create_task( + snmpc.run_endpoint(ep, bindings), + name=f"snmp:{ep.endpoint_key}", + ) + ) + + # --- Prologix (one task per channel/binding) --- + prolc = PrologixEndpointCollector(writer, default_poll_s=cfg.default_poll_interval_s) + for ep, dev, ch in by_proto.get("prologix", []): + src = ch.source + if src.get("type") != "scpi": + continue + b = PrologixBinding( + endpoint=ep, + device=dev, + channel=ch, + query=str(src["query"]), + datatype=str(src.get("datatype", "float")), + ) + tasks.append( + asyncio.create_task( + prolc.run_binding(b), + name=f"prologix:{dev.device_id}:{ch.metric}", + ) + ) + + # --- VISA --- + # visac = VisaCollector(writer, default_poll_s=cfg.default_poll_interval_s) + # for ep, dev, ch in by_proto.get("visa", []): + # src = ch.source + # if src.get("type") != "scpi": + # continue + # b = VisaBinding( + # endpoint=ep, + # device=dev, + # channel=ch, + # query=str(src["query"]), + # datatype=str(src.get("datatype", "float")), + # ) + # tasks.append( + # asyncio.create_task( + # visac.run_binding(b), + # name=f"visa:{dev.device_id}:{ch.metric}", + # ) + # ) + + if not tasks: + log.warning("No enabled channels found. Exiting.") + return + + log.info("Running with %d task(s)", len(tasks)) + await asyncio.gather(*tasks) + finally: + await db.stop() + + +def main() -> None: + asyncio.run(_run()) diff --git a/sensgw/metrics.py b/sensgw/metrics.py new file mode 100644 index 0000000..3ac8a91 --- /dev/null +++ b/sensgw/metrics.py @@ -0,0 +1,13 @@ +ALLOWED_METRICS = { + "temp_c", + "humidity_rh", + "pressure_pa", + "light_lux", + "soil_moist", + "co2_ppm", + "voltage_v", + "current_a", + "resistance_ohm", + "freq_hz", + "power_w", +} diff --git a/sensgw/models.py b/sensgw/models.py new file mode 100644 index 0000000..2b4721c --- /dev/null +++ b/sensgw/models.py @@ -0,0 +1,35 @@ +# sensgw/models.py +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +@dataclass(frozen=True) +class Endpoint: + endpoint_id: int + endpoint_key: str + protocol: str + conn: Dict[str, Any] + is_enabled: bool + + +@dataclass(frozen=True) +class Device: + device_id: int + device_key: str + endpoint_id: Optional[int] + location_id: Optional[int] + is_enabled: bool + metadata: Dict[str, Any] + + +@dataclass(frozen=True) +class Channel: + channel_id: int + device_id: int + metric: str + source: Dict[str, Any] + scale_value: float + offset_value: float + poll_interval_s: Optional[int] + is_enabled: bool + diff --git a/sensgw/protocols/mqtt.py b/sensgw/protocols/mqtt.py new file mode 100644 index 0000000..bd84e42 --- /dev/null +++ b/sensgw/protocols/mqtt.py @@ -0,0 +1,132 @@ +# sensgw/protocols/mqtt.py +from __future__ import annotations + +import asyncio +import datetime as dt +import json +from dataclasses import dataclass +from typing import Any + +import paho.mqtt.client as mqtt + +from ..models import Endpoint, Device, Channel +from ..writer import Writer + + +@dataclass(frozen=True) +class MqttBinding: + endpoint: Endpoint + device: Device + channel: Channel + topic: str + field: str + payload: str # "json" | "text" + + +class MqttCollector: + def __init__(self, writer: Writer): + self.writer = writer + self._queue: asyncio.Queue[tuple[str, bytes, dt.datetime]] = asyncio.Queue() + self._client: mqtt.Client | None = None + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage) -> None: + ts = dt.datetime.now(dt.timezone.utc) + loop: asyncio.AbstractEventLoop = userdata["loop"] + loop.call_soon_threadsafe( + self._queue.put_nowait, + (str(msg.topic), bytes(msg.payload), ts), + ) + + @staticmethod + def _extract_numeric(*, payload_kind: str, field: str, payload: bytes) -> float: + kind = (payload_kind or "json").strip().lower() + fld = (field or "").strip() + text = payload.decode("utf-8", errors="replace").strip() + + if kind == "text": + return float(text) + + if kind != "json": + raise ValueError(f"unsupported payload kind: {kind}") + + obj = json.loads(text) + + # If field is provided, expect a JSON object/dict + if fld: + if not isinstance(obj, dict): + raise ValueError(f"expected JSON object for field='{fld}', got {type(obj).__name__}") + if fld not in obj: + raise KeyError(f"missing field '{fld}'") + return float(obj[fld]) + + # No field => accept JSON scalar (number/string) or object with common keys + if isinstance(obj, (int, float, str)): + return float(obj) + if isinstance(obj, dict): + for k in ("value", "val", "v"): + if k in obj: + return float(obj[k]) + raise KeyError("no field specified and no default key found in JSON object") + raise ValueError(f"unsupported JSON type: {type(obj).__name__}") + + async def run(self, bindings: list[MqttBinding]) -> None: + if not bindings: + return + + ep = bindings[0].endpoint + host = ep.conn.get("host", "localhost") + port = int(ep.conn.get("port", 1883)) + client_id = ep.conn.get("client_id", "sensgw") + + loop = asyncio.get_running_loop() + c = mqtt.Client( + mqtt.CallbackAPIVersion.VERSION2, + client_id=client_id, + userdata={"loop": loop}, + ) + c.on_message = self._on_message + + # Optional auth + if "username" in ep.conn: + c.username_pw_set(ep.conn["username"], ep.conn.get("password")) + + c.connect(host, port, keepalive=30) + + topic_to_bindings: dict[str, list[MqttBinding]] = {} + for b in bindings: + topic_to_bindings.setdefault(b.topic, []).append(b) + + for t in sorted(topic_to_bindings.keys()): + c.subscribe(t) + + c.loop_start() + self._client = c + + try: + while True: + topic, payload, ts = await self._queue.get() + for b in topic_to_bindings.get(topic, []): + try: + value = self._extract_numeric( + payload_kind=b.payload, + field=b.field, + payload=payload, + ) + value = value * b.channel.scale_value + b.channel.offset_value + + await self.writer.write_metric( + ts=ts, + device_id=b.device.device_id, + location_id=b.device.location_id, + metric=b.channel.metric, + value=value, + ) + except Exception as e: + await self.writer.write_error( + device_id=b.device.device_id, + error=f"mqtt parse/write: {e}", + ) + finally: + c.loop_stop() + c.disconnect() + diff --git a/sensgw/protocols/polling.py b/sensgw/protocols/polling.py new file mode 100644 index 0000000..c8c940c --- /dev/null +++ b/sensgw/protocols/polling.py @@ -0,0 +1,26 @@ +# sensgw/protocols/polling.py +import asyncio +import datetime as dt +from typing import Awaitable, Callable, Optional + + +async def poll_forever( + *, + interval_s: int, + read_once: Callable[[], Awaitable[None]], + jitter_s: float = 0.0, + stop_event: Optional[asyncio.Event] = None, +) -> None: + if jitter_s: + await asyncio.sleep(jitter_s) + + while True: + if stop_event and stop_event.is_set(): + return + start = dt.datetime.now(dt.timezone.utc) + try: + await read_once() + finally: + elapsed = (dt.datetime.now(dt.timezone.utc) - start).total_seconds() + sleep_s = max(0.0, interval_s - elapsed) + await asyncio.sleep(sleep_s) diff --git a/sensgw/protocols/prologix.py b/sensgw/protocols/prologix.py new file mode 100644 index 0000000..a2de174 --- /dev/null +++ b/sensgw/protocols/prologix.py @@ -0,0 +1,186 @@ +# sensgw/protocols/prologix.py +from __future__ import annotations + +import asyncio +import datetime as dt +import threading +from dataclasses import dataclass +from typing import Any + +from ..models import Endpoint, Device, Channel +from ..writer import Writer +from .polling import poll_forever + + +@dataclass(frozen=True) +class PrologixBinding: + endpoint: Endpoint + device: Device + channel: Channel + query: str + datatype: str # "float" | "int" | ... + + +def _parse_numeric(datatype: str, raw: str) -> float: + kind = (datatype or "float").strip().lower() + if kind == "int": + return float(int(raw)) + # default: float + return float(raw) + + +def _driver_class(driver_key: str) -> type[Any] | None: + """ + Map driver keys stored in DB to PyMeasure instrument classes. + + devices.metadata.driver examples: + - "keithley2000" + """ + key = (driver_key or "").strip().lower() + if not key: + return None + + if key in {"keithley2000", "keithley_2000", "keithley:2000"}: + from pymeasure.instruments.keithley import Keithley2000 # type: ignore + + return Keithley2000 + + # Add more mappings here as you add support. + return None + + +class PrologixEndpointClient: + """ + One shared Prologix adapter per endpoint, protected by a lock because it is stateful + (address switching) and not safe to use concurrently. + + If a device specifies devices.metadata.driver, we create a PyMeasure Instrument on top + of the same adapter and run queries through instrument.ask(). + """ + + def __init__(self, endpoint: Endpoint): + self.endpoint = endpoint + self._lock = threading.Lock() + self._adapter: Any | None = None + self._instruments: dict[tuple[int, str], Any] = {} + + def _get_adapter(self) -> Any: + if self._adapter is None: + from pymeasure.adapters import PrologixAdapter # type: ignore + + try: + resource = self.endpoint.conn["resource"] + except KeyError as e: + raise RuntimeError( + f"Missing endpoint.conn['resource'] for endpoint_id={self.endpoint.endpoint_id}" + ) from e + + read_timeout = int(self.endpoint.conn.get("gpib_read_timeout_ms", 200)) + auto = bool(self.endpoint.conn.get("auto", False)) + + self._adapter = PrologixAdapter( + resource, + gpib_read_timeout=read_timeout, + auto=auto, + ) + try: + self._adapter.flush_read_buffer() + except Exception: + pass + return self._adapter + + def _get_instrument(self, *, gpib_addr: int, driver_key: str) -> Any: + """ + Cached per (addr, driver_key). Uses the shared adapter. + """ + key = (gpib_addr, driver_key.strip().lower()) + inst = self._instruments.get(key) + if inst is not None: + return inst + + cls = _driver_class(driver_key) + if cls is None: + raise KeyError(f"Unknown driver '{driver_key}'") + + ad = self._get_adapter() + # Ensure the adapter is pointed at the correct instrument when the driver is constructed. + ad.address = gpib_addr + + inst = cls(ad) + self._instruments[key] = inst + return inst + + def query(self, *, gpib_addr: int, cmd: str, driver_key: str | None = None) -> str: + """ + Execute a query at a given GPIB address. + If driver_key is provided and known, execute via driver (instrument.ask). + Otherwise, raw adapter write/read. + """ + with self._lock: + ad = self._get_adapter() + ad.address = gpib_addr + + if driver_key: + inst = self._get_instrument(gpib_addr=gpib_addr, driver_key=driver_key) + # Keep the endpoint lock held across ask(); it may do multiple I/O ops. + return str(inst.ask(cmd)).strip() + + ad.write(cmd) + return str(ad.read()).strip() + + +class PrologixEndpointCollector: + def __init__(self, writer: Writer, default_poll_s: int): + self.writer = writer + self.default_poll_s = default_poll_s + self._clients: dict[int, PrologixEndpointClient] = {} + + def _client(self, endpoint: Endpoint) -> PrologixEndpointClient: + client = self._clients.get(endpoint.endpoint_id) + if client is None: + client = PrologixEndpointClient(endpoint) + self._clients[endpoint.endpoint_id] = client + return client + + async def run_binding(self, b: PrologixBinding) -> None: + interval_s = int(b.channel.poll_interval_s or self.default_poll_s) + client = self._client(b.endpoint) + + gpib_addr = b.device.metadata.get("gpib_addr") + if gpib_addr is None: + raise RuntimeError( + f"Missing device.metadata.gpib_addr for device_id={b.device.device_id}" + ) + gpib_addr = int(gpib_addr) + + driver_key = b.device.metadata.get("driver") + driver_key = str(driver_key).strip() if driver_key else None + + async def read_once() -> None: + ts = dt.datetime.now(dt.timezone.utc) + try: + raw = await asyncio.to_thread( + client.query, + gpib_addr=gpib_addr, + cmd=b.query, + driver_key=driver_key, + ) + + v = _parse_numeric(b.datatype, raw) + v = v * b.channel.scale_value + b.channel.offset_value + + await self.writer.write_metric( + ts=ts, + device_id=b.device.device_id, + location_id=b.device.location_id, + metric=b.channel.metric, + value=v, + ) + except Exception as e: + await self.writer.write_error( + device_id=b.device.device_id, + error=f"prologix: {e}", + ) + + await poll_forever(interval_s=interval_s, read_once=read_once) + diff --git a/sensgw/protocols/snmp.py b/sensgw/protocols/snmp.py new file mode 100644 index 0000000..5ef664d --- /dev/null +++ b/sensgw/protocols/snmp.py @@ -0,0 +1,152 @@ +# sensgw/protocols/snmp.py +from __future__ import annotations + +import datetime as dt +from dataclasses import dataclass + +from ..models import Endpoint, Device, Channel +from ..writer import Writer +from .polling import poll_forever + + +@dataclass(frozen=True) +class SnmpBinding: + endpoint: Endpoint + device: Device + channel: Channel + oid: str + datatype: str # "float" | "int" | ... + + +def _parse_numeric(datatype: str, raw: str) -> float: + kind = (datatype or "float").strip().lower() + if kind == "int": + return float(int(raw)) + return float(raw) + + +def _parse_version(conn: dict) -> int: + """ + Return mpModel: + SNMPv1 -> 0 + SNMPv2c -> 1 + """ + v = str(conn.get("version", "2c")).lower() + if v in {"1", "v1", "snmpv1"}: + return 0 + return 1 + + +class SnmpEndpointCollector: + def __init__(self, writer: Writer, default_poll_s: int): + self.writer = writer + self.default_poll_s = default_poll_s + + async def _get_many( + self, + *, + host: str, + port: int, + community: str, + mp_model: int, + timeout_s: int, + oids: list[str], + ) -> dict[str, str]: + from pysnmp.hlapi.v3arch.asyncio import ( # type: ignore + SnmpEngine, + CommunityData, + UdpTransportTarget, + ContextData, + ObjectType, + ObjectIdentity, + get_cmd, + ) + + snmp_engine = SnmpEngine() + try: + var_binds = [ObjectType(ObjectIdentity(oid)) for oid in oids] + + # In pysnmp 7.x, target creation is async: + target = await UdpTransportTarget.create((host, port), timeout=timeout_s, retries=0) + + iterator = get_cmd( + snmp_engine, + CommunityData(community, mpModel=mp_model), + target, + ContextData(), + *var_binds, + ) + + error_indication, error_status, error_index, out_binds = await iterator + + if error_indication: + raise RuntimeError(str(error_indication)) + if error_status: + raise RuntimeError( + f"{error_status.prettyPrint()} at " + f"{out_binds[int(error_index) - 1][0] if error_index else '?'}" + ) + + return {str(name): str(val) for name, val in out_binds} + finally: + snmp_engine.close_dispatcher() + + async def run_endpoint(self, endpoint: Endpoint, bindings: list[SnmpBinding]) -> None: + host = endpoint.conn["host"] + port = int(endpoint.conn.get("port", 161)) + community = endpoint.conn.get("community", "public") + timeout_s = int(endpoint.conn.get("timeout_s", 2)) + mp_model = _parse_version(endpoint.conn) + + intervals = [ + int(b.channel.poll_interval_s) + for b in bindings + if b.channel.poll_interval_s is not None + ] + interval_s = min(intervals) if intervals else self.default_poll_s + + oid_to_binding: dict[str, SnmpBinding] = {b.oid.strip(): b for b in bindings} + oids = list(oid_to_binding.keys()) + + async def read_once() -> None: + ts = dt.datetime.now(dt.timezone.utc) + try: + values = await self._get_many( + host=host, + port=port, + community=community, + mp_model=mp_model, + timeout_s=timeout_s, + oids=oids, + ) + + for oid_str, raw in values.items(): + b = oid_to_binding.get(oid_str) + if b is None: + continue + try: + v = _parse_numeric(b.datatype, raw) + v = v * b.channel.scale_value + b.channel.offset_value + + await self.writer.write_metric( + ts=ts, + device_id=b.device.device_id, + location_id=b.device.location_id, + metric=b.channel.metric, + value=v, + ) + except Exception as e: + await self.writer.write_error( + device_id=b.device.device_id, + error=f"snmp parse/write: {e}", + ) + except Exception as e: + # Endpoint-level failure: mark all devices as error + for b in bindings: + await self.writer.write_error( + device_id=b.device.device_id, + error=f"snmp endpoint: {e}", + ) + + await poll_forever(interval_s=interval_s, read_once=read_once) + diff --git a/sensgw/protocols/visa.py b/sensgw/protocols/visa.py new file mode 100644 index 0000000..ee2bebb --- /dev/null +++ b/sensgw/protocols/visa.py @@ -0,0 +1,84 @@ +# sensgw/protocols/visa.py +from __future__ import annotations + +import asyncio +import datetime as dt +from dataclasses import dataclass + +from ..models import Endpoint, Device, Channel +from ..writer import Writer +from .polling import poll_forever + + +@dataclass(frozen=True) +class VisaBinding: + endpoint: Endpoint + device: Device + channel: Channel + query: str + datatype: str # "float" etc + + +def _visa_query_sync(*, resource: str, conn: dict, device_meta: dict, query: str) -> str: + import pyvisa # type: ignore + + rm = pyvisa.ResourceManager() + inst = rm.open_resource(resource) + + # Optional serial config + if "baud_rate" in conn and hasattr(inst, "baud_rate"): + inst.baud_rate = int(conn["baud_rate"]) + + if "read_termination" in conn: + inst.read_termination = str(conn["read_termination"]) + if "write_termination" in conn: + inst.write_termination = str(conn["write_termination"]) + + # If you're using a Prologix-like controller over serial, you may need to set addr. + # This is device-specific; keeping it optional: + gpib_addr = device_meta.get("gpib_addr") + if gpib_addr is not None: + inst.write(f"++addr {int(gpib_addr)}") + + return str(inst.query(query)).strip() + + +class VisaCollector: + def __init__(self, writer: Writer, default_poll_s: int): + self.writer = writer + self.default_poll_s = default_poll_s + + async def run_binding(self, b: VisaBinding) -> None: + ep = b.endpoint + resource = ep.conn["resource"] + interval_s = int(b.channel.poll_interval_s or self.default_poll_s) + + async def read_once() -> None: + try: + raw = await asyncio.to_thread( + _visa_query_sync, + resource=resource, + conn=ep.conn, + device_meta=b.device.metadata, + query=b.query, + ) + if b.datatype == "float": + value = float(raw) + elif b.datatype == "int": + value = float(int(raw)) + else: + value = float(raw) + + value = value * b.channel.scale_value + b.channel.offset_value + ts = dt.datetime.now(dt.timezone.utc) + await self.writer.write_metric( + ts=ts, + device_id=b.device.device_id, + location_id=b.device.location_id, + metric=b.channel.metric, + value=value, + ) + except Exception as e: + await self.writer.write_error(device_id=b.device.device_id, error=f"visa: {e}") + + await poll_forever(interval_s=interval_s, read_once=read_once) diff --git a/sensgw/registry.py b/sensgw/registry.py new file mode 100644 index 0000000..a94ffec --- /dev/null +++ b/sensgw/registry.py @@ -0,0 +1,93 @@ +# sensgw/registry.py +from dataclasses import dataclass +from typing import Dict, List, Tuple + +from .models import Endpoint, Device, Channel +from .db import Database + + +@dataclass(frozen=True) +class Registry: + endpoints: Dict[int, Endpoint] + devices: Dict[int, Device] + channels: List[Channel] + + def channels_by_protocol(self) -> Dict[str, List[Tuple[Endpoint, Device, Channel]]]: + out: Dict[str, List[Tuple[Endpoint, Device, Channel]]] = {} + for ch in self.channels: + dev = self.devices.get(ch.device_id) + if not dev or not dev.is_enabled or dev.endpoint_id is None: + continue + ep = self.endpoints.get(dev.endpoint_id) + if not ep or not ep.is_enabled: + continue + out.setdefault(ep.protocol, []).append((ep, dev, ch)) + return out + + +async def load_registry(db: Database) -> Registry: + assert db.pool is not None + + async with db.pool.acquire() as con: + ep_rows = await con.fetch( + """ + select endpoint_id, endpoint_key, protocol, conn, is_enabled + from endpoints + where is_enabled = true + """ + ) + dev_rows = await con.fetch( + """ + select device_id, device_key, endpoint_id, location_id, is_enabled, metadata + from devices + where is_enabled = true + """ + ) + ch_rows = await con.fetch( + """ + select channel_id, device_id, metric, source, scale_value, offset_value, + poll_interval_s, is_enabled + from device_channels + where is_enabled = true + """ + ) + + endpoints = { + int(r["endpoint_id"]): Endpoint( + endpoint_id=int(r["endpoint_id"]), + endpoint_key=str(r["endpoint_key"]), + protocol=str(r["protocol"]), + conn=(r["conn"] or {}), + is_enabled=bool(r["is_enabled"]), + ) + for r in ep_rows + } + + devices = { + int(r["device_id"]): Device( + device_id=int(r["device_id"]), + device_key=str(r["device_key"]), + endpoint_id=(int(r["endpoint_id"]) if r["endpoint_id"] is not None else None), + location_id=(int(r["location_id"]) if r["location_id"] is not None else None), + is_enabled=bool(r["is_enabled"]), + metadata=(r["metadata"] or {}), + ) + for r in dev_rows + } + + channels = [ + Channel( + channel_id=int(r["channel_id"]), + device_id=int(r["device_id"]), + metric=str(r["metric"]), + source=(r["source"] or {}), + scale_value=float(r["scale_value"]), + offset_value=float(r["offset_value"]), + poll_interval_s=(int(r["poll_interval_s"]) if r["poll_interval_s"] is not None else None), + is_enabled=bool(r["is_enabled"]), + ) + for r in ch_rows + ] + + return Registry(endpoints=endpoints, devices=devices, channels=channels) + diff --git a/sensgw/writer.py b/sensgw/writer.py new file mode 100644 index 0000000..869e128 --- /dev/null +++ b/sensgw/writer.py @@ -0,0 +1,75 @@ +# sensgw/writer.py +from __future__ import annotations + +import datetime as dt +from typing import Optional + +from .db import Database +from .metrics import ALLOWED_METRICS + + +class Writer: + def __init__(self, db: Database): + self.db = db + + async def write_metric( + self, + *, + ts: dt.datetime, + device_id: int, + location_id: Optional[int], + metric: str, + value: float, + ) -> None: + if metric not in ALLOWED_METRICS: + raise ValueError(f"Metric not allowed: {metric}") + + assert self.db.pool is not None + # Safe because we validate metric against allow-list above. + col = metric + + async with self.db.pool.acquire() as con: + async with con.transaction(): + await con.execute( + f""" + insert into sensor_data (ts, device_id, location_id, {col}) + values ($1, $2, $3, $4) + on conflict (device_id, ts) do update + set {col} = excluded.{col}, + location_id = coalesce(excluded.location_id, sensor_data.location_id) + """, + ts, + device_id, + location_id, + value, + ) + await con.execute( + """ + insert into device_status (device_id, last_seen, last_ok, updated_at) + values ($1, now(), now(), now()) + on conflict (device_id) do update + set last_seen = excluded.last_seen, + last_ok = excluded.last_ok, + updated_at = excluded.updated_at, + last_error_at = null, + last_error = null + """, + device_id, + ) + + async def write_error(self, *, device_id: int, error: str) -> None: + assert self.db.pool is not None + async with self.db.pool.acquire() as con: + await con.execute( + """ + insert into device_status (device_id, last_seen, last_error_at, last_error, updated_at) + values ($1, now(), now(), $2, now()) + on conflict (device_id) do update + set last_seen = excluded.last_seen, + last_error_at = excluded.last_error_at, + last_error = excluded.last_error, + updated_at = excluded.updated_at + """, + device_id, + error[:2000], + )