Forráskód Böngészése

Merge branch 'socket-abstraction' into master

Oxan van Leeuwen 3 éve
szülő
commit
487d02f303

+ 23 - 3
README.md

@@ -1,8 +1,8 @@
 Stream server for ESPHome
 =========================
 
-Custom component for ESPHome to expose a UART stream over WiFi or Ethernet. Can be used as a serial-to-wifi bridge as
-known from ESPLink or ser2net by using ESPHome.
+Custom component for ESPHome to expose a UART stream over WiFi or Ethernet. Provides a serial-to-wifi bridge as known
+from ESPLink or ser2net, using ESPHome.
 
 This component creates a TCP server listening on port 6638 (by default), and relays all data between the connected
 clients and the serial port. It doesn't support any control sequences, telnet options or RFC 2217, just raw data.
@@ -10,7 +10,7 @@ clients and the serial port. It doesn't support any control sequences, telnet op
 Usage
 -----
 
-Requires ESPHome v1.18.0 or higher.
+Requires ESPHome v2022.3.0 or newer.
 
 ```yaml
 external_components:
@@ -30,3 +30,23 @@ stream_server:
    uart_id: uart_bus
    port: 1234
 ```
+
+Sensors
+-------
+The server provides a binary sensor that signals whether there currently is a client connected:
+
+```yaml
+binary_sensor:
+  - platform: stream_server
+    connected:
+      name: Connected
+```
+
+It also provides a numeric sensor that indicates the number of connected clients:
+
+```yaml
+sensor:
+  - platform: stream_server
+    connection_count:
+      name: Number of connections
+```

+ 33 - 20
components/stream_server/__init__.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2021 Oxan van Leeuwen
+# Copyright (C) 2021-2023 Oxan van Leeuwen
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -16,33 +16,46 @@
 import esphome.codegen as cg
 import esphome.config_validation as cv
 from esphome.components import uart
-from esphome.const import CONF_ID, CONF_PORT
+from esphome.const import CONF_ID, CONF_PORT, CONF_BUFFER_SIZE
 
 # ESPHome doesn't know the Stream abstraction yet, so hardcode to use a UART for now.
 
-AUTO_LOAD = ["async_tcp"]
+AUTO_LOAD = ["socket"]
 
 DEPENDENCIES = ["uart", "network"]
 
 MULTI_CONF = True
 
-StreamServerComponent = cg.global_ns.class_("StreamServerComponent", cg.Component)
-
-CONFIG_SCHEMA = (
-	cv.Schema(
-		{
-			cv.GenerateID(): cv.declare_id(StreamServerComponent),
-			cv.Optional(CONF_PORT): cv.port,
-		}
-	)
-		.extend(cv.COMPONENT_SCHEMA)
-		.extend(uart.UART_DEVICE_SCHEMA)
+ns = cg.global_ns
+StreamServerComponent = ns.class_("StreamServerComponent", cg.Component)
+
+
+def validate_buffer_size(buffer_size):
+    if buffer_size & (buffer_size - 1) != 0:
+        raise cv.Invalid("Buffer size must be a power of two.")
+    return buffer_size
+
+
+CONFIG_SCHEMA = cv.All(
+    cv.require_esphome_version(2022, 3, 0),
+    cv.Schema(
+        {
+            cv.GenerateID(): cv.declare_id(StreamServerComponent),
+            cv.Optional(CONF_PORT, default=6638): cv.port,
+            cv.Optional(CONF_BUFFER_SIZE, default=128): cv.All(
+                cv.positive_int, validate_buffer_size
+            ),
+        }
+    )
+    .extend(cv.COMPONENT_SCHEMA)
+    .extend(uart.UART_DEVICE_SCHEMA),
 )
 
-def to_code(config):
-	var = cg.new_Pvariable(config[CONF_ID])
-	if CONF_PORT in config:
-		cg.add(var.set_port(config[CONF_PORT]))
 
-	yield cg.register_component(var, config)
-	yield uart.register_uart_device(var, config)
+async def to_code(config):
+    var = cg.new_Pvariable(config[CONF_ID])
+    cg.add(var.set_port(config[CONF_PORT]))
+    cg.add(var.set_buffer_size(config[CONF_BUFFER_SIZE]))
+
+    await cg.register_component(var, config)
+    await uart.register_uart_device(var, config)

+ 43 - 0
components/stream_server/binary_sensor.py

@@ -0,0 +1,43 @@
+# Copyright (C) 2023 Oxan van Leeuwen
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+import esphome.codegen as cg
+import esphome.config_validation as cv
+from esphome.components import binary_sensor
+from esphome.const import (
+    DEVICE_CLASS_CONNECTIVITY,
+    ENTITY_CATEGORY_DIAGNOSTIC,
+)
+from . import ns, StreamServerComponent
+
+CONF_CONNECTED = "connected"
+CONF_STREAM_SERVER = "stream_server"
+
+CONFIG_SCHEMA = cv.Schema(
+    {
+        cv.GenerateID(CONF_STREAM_SERVER): cv.use_id(StreamServerComponent),
+        cv.Required(CONF_CONNECTED): binary_sensor.binary_sensor_schema(
+            device_class=DEVICE_CLASS_CONNECTIVITY,
+            entity_category=ENTITY_CATEGORY_DIAGNOSTIC,
+        ),
+    }
+)
+
+
+async def to_code(config):
+    server = await cg.get_variable(config[CONF_STREAM_SERVER])
+
+    sens = await binary_sensor.new_binary_sensor(config[CONF_CONNECTED])
+    cg.add(server.set_connected_sensor(sens))

+ 44 - 0
components/stream_server/sensor.py

@@ -0,0 +1,44 @@
+# Copyright (C) 2023 Oxan van Leeuwen
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+import esphome.codegen as cg
+import esphome.config_validation as cv
+from esphome.components import sensor
+from esphome.const import (
+    STATE_CLASS_MEASUREMENT,
+    ENTITY_CATEGORY_DIAGNOSTIC,
+)
+from . import ns, StreamServerComponent
+
+CONF_CONNECTION_COUNT = "connection_count"
+CONF_STREAM_SERVER = "stream_server"
+
+CONFIG_SCHEMA = cv.Schema(
+    {
+        cv.GenerateID(CONF_STREAM_SERVER): cv.use_id(StreamServerComponent),
+        cv.Required(CONF_CONNECTION_COUNT): sensor.sensor_schema(
+            accuracy_decimals=0,
+            state_class=STATE_CLASS_MEASUREMENT,
+            entity_category=ENTITY_CATEGORY_DIAGNOSTIC,
+        ),
+    }
+)
+
+
+async def to_code(config):
+    server = await cg.get_variable(config[CONF_STREAM_SERVER])
+
+    sens = await sensor.new_sensor(config[CONF_CONNECTION_COUNT])
+    cg.add(server.set_connection_count_sensor(sens))

+ 115 - 72
components/stream_server/stream_server.cpp

@@ -1,4 +1,4 @@
-/* Copyright (C) 2020-2021 Oxan van Leeuwen
+/* Copyright (C) 2020-2023 Oxan van Leeuwen
  *
  * This program is free software: you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
@@ -16,108 +16,151 @@
 
 #include "stream_server.h"
 
+#include "esphome/core/helpers.h"
 #include "esphome/core/log.h"
 #include "esphome/core/util.h"
 
-#if ESPHOME_VERSION_CODE >= VERSION_CODE(2021, 10, 0)
 #include "esphome/components/network/util.h"
-#endif
-
+#include "esphome/components/socket/socket.h"
 
-static const char *TAG = "streamserver";
+static const char *TAG = "stream_server";
 
 using namespace esphome;
 
 void StreamServerComponent::setup() {
     ESP_LOGCONFIG(TAG, "Setting up stream server...");
-    this->recv_buf_.reserve(128);
 
-    this->server_ = AsyncServer(this->port_);
-    this->server_.begin();
-    this->server_.onClient([this](void *h, AsyncClient *tcpClient) {
-        if(tcpClient == nullptr)
-            return;
+    // The make_unique() wrapper doesn't like arrays, so initialize the unique_ptr directly.
+    this->buf_ = std::unique_ptr<uint8_t[]>{new uint8_t[this->buf_size_]};
+
+    struct sockaddr_storage bind_addr;
+    socklen_t bind_addrlen = socket::set_sockaddr_any(reinterpret_cast<struct sockaddr *>(&bind_addr), sizeof(bind_addr), htons(this->port_));
 
-        this->clients_.push_back(std::unique_ptr<Client>(new Client(tcpClient, this->recv_buf_)));
-    }, this);
+    this->socket_ = socket::socket_ip(SOCK_STREAM, PF_INET);
+    this->socket_->setblocking(false);
+    this->socket_->bind(reinterpret_cast<struct sockaddr *>(&bind_addr), bind_addrlen);
+    this->socket_->listen(8);
 }
 
 void StreamServerComponent::loop() {
-    this->cleanup();
+    this->accept();
     this->read();
+    this->flush();
     this->write();
+    this->cleanup();
 }
 
-void StreamServerComponent::cleanup() {
-    auto discriminator = [](std::unique_ptr<Client> &client) { return !client->disconnected; };
-    auto last_client = std::partition(this->clients_.begin(), this->clients_.end(), discriminator);
-    for (auto it = last_client; it != this->clients_.end(); it++)
-        ESP_LOGD(TAG, "Client %s disconnected", (*it)->identifier.c_str());
-
-    this->clients_.erase(last_client, this->clients_.end());
-}
-
-void StreamServerComponent::read() {
-    int len;
-    while ((len = this->stream_->available()) > 0) {
-        char buf[128];
-        len = std::min(len, 128);
-#if ESPHOME_VERSION_CODE >= VERSION_CODE(2021, 10, 0)
-        this->stream_->read_array(reinterpret_cast<uint8_t*>(buf), len);
-#else
-        this->stream_->readBytes(buf, len);
+void StreamServerComponent::dump_config() {
+    ESP_LOGCONFIG(TAG, "Stream Server:");
+    ESP_LOGCONFIG(TAG, "  Address: %s:%u", esphome::network::get_use_address().c_str(), this->port_);
+#ifdef USE_BINARY_SENSOR
+    LOG_BINARY_SENSOR("  ", "Connected:", this->connected_sensor_);
+#endif
+#ifdef USE_SENSOR
+    LOG_SENSOR("  ", "Connection count:", this->connection_count_sensor_);
 #endif
-        for (auto const& client : this->clients_)
-            client->tcp_client->write(buf, len);
-    }
 }
 
-void StreamServerComponent::write() {
-#if ESPHOME_VERSION_CODE >= VERSION_CODE(2021, 10, 0)
-    this->stream_->write_array(this->recv_buf_);
-    this->recv_buf_.clear();
-#else
-    size_t len;
-    while ((len = this->recv_buf_.size()) > 0) {
-        this->stream_->write(this->recv_buf_.data(), len);
-        this->recv_buf_.erase(this->recv_buf_.begin(), this->recv_buf_.begin() + len);
-    }
-#endif
+void StreamServerComponent::on_shutdown() {
+    for (const Client &client : this->clients_)
+        client.socket->shutdown(SHUT_RDWR);
 }
 
-void StreamServerComponent::dump_config() {
-    ESP_LOGCONFIG(TAG, "Stream Server:");
-    ESP_LOGCONFIG(TAG, "  Address: %s:%u",
-#if ESPHOME_VERSION_CODE >= VERSION_CODE(2021, 10, 0)
-                  esphome::network::get_ip_address().str().c_str(),
-#else
-                  network_get_address().c_str(),
+void StreamServerComponent::publish_sensor() {
+#ifdef USE_BINARY_SENSOR
+    if (this->connected_sensor_)
+        this->connected_sensor_->publish_state(this->clients_.size() > 0);
+#endif
+#ifdef USE_SENSOR
+    if (this->connection_count_sensor_)
+        this->connection_count_sensor_->publish_state(this->clients_.size());
 #endif
-                  this->port_);
 }
 
-void StreamServerComponent::on_shutdown() {
-    for (auto &client : this->clients_)
-        client->tcp_client->close(true);
+void StreamServerComponent::accept() {
+    struct sockaddr_storage client_addr;
+    socklen_t client_addrlen = sizeof(client_addr);
+    std::unique_ptr<socket::Socket> socket = this->socket_->accept(reinterpret_cast<struct sockaddr *>(&client_addr), &client_addrlen);
+    if (!socket)
+        return;
+
+    socket->setblocking(false);
+    std::string identifier = socket->getpeername();
+    this->clients_.emplace_back(std::move(socket), identifier, this->buf_head_);
+    ESP_LOGD(TAG, "New client connected from %s", identifier.c_str());
+    this->publish_sensor();
 }
 
-StreamServerComponent::Client::Client(AsyncClient *client, std::vector<uint8_t> &recv_buf) :
-        tcp_client{client}, identifier{client->remoteIP().toString().c_str()}, disconnected{false} {
-    ESP_LOGD(TAG, "New client connected from %s", this->identifier.c_str());
-
-    this->tcp_client->onError(     [this](void *h, AsyncClient *client, int8_t error)  { this->disconnected = true; });
-    this->tcp_client->onDisconnect([this](void *h, AsyncClient *client)                { this->disconnected = true; });
-    this->tcp_client->onTimeout(   [this](void *h, AsyncClient *client, uint32_t time) { this->disconnected = true; });
+void StreamServerComponent::cleanup() {
+    auto discriminator = [](const Client &client) { return !client.disconnected; };
+    auto last_client = std::partition(this->clients_.begin(), this->clients_.end(), discriminator);
+    if (last_client != this->clients_.end()) {
+        this->clients_.erase(last_client, this->clients_.end());
+        this->publish_sensor();
+    }
+}
 
-    this->tcp_client->onData([&](void *h, AsyncClient *client, void *data, size_t len) {
-        if (len == 0 || data == nullptr)
-            return;
+void StreamServerComponent::read() {
+    size_t len = 0;
+    int available;
+    while ((available = this->stream_->available()) > 0) {
+        size_t free = this->buf_size_ - (this->buf_head_ - this->buf_tail_);
+        if (free == 0) {
+            // Only overwrite if nothing has been added yet, otherwise give flush() a chance to empty the buffer first.
+            if (len > 0)
+                return;
+
+            ESP_LOGE(TAG, "Incoming bytes available, but outgoing buffer is full: stream will be corrupted!");
+            free = std::min<size_t>(available, this->buf_size_);
+            this->buf_tail_ += free;
+            for (Client &client : this->clients_) {
+                if (client.position < this->buf_tail_) {
+                    ESP_LOGW(TAG, "Dropped %u pending bytes for client %s", this->buf_tail_ - client.position, client.identifier.c_str());
+                    client.position = this->buf_tail_;
+                }
+            }
+
+        }
+
+        // Fill all available contiguous space in the ring buffer.
+        len = std::min<size_t>(available, std::min<size_t>(this->buf_ahead(this->buf_head_), free));
+        this->stream_->read_array(&this->buf_[this->buf_index(this->buf_head_)], len);
+        this->buf_head_ += len;
+    }
+}
 
-        auto buf = static_cast<uint8_t *>(data);
-        recv_buf.insert(recv_buf.end(), buf, buf + len);
-    }, nullptr);
+void StreamServerComponent::flush() {
+    this->buf_tail_ = this->buf_head_;
+    for (Client &client : this->clients_) {
+        if (client.position == this->buf_head_)
+            continue;
+
+        // Split the write into two parts: from the current position to the end of the ring buffer, and from the start
+        // of the ring buffer until the head. The second part might be zero if no wraparound is necessary.
+        struct iovec iov[2];
+        iov[0].iov_base = &this->buf_[this->buf_index(client.position)];
+        iov[0].iov_len = std::min(this->buf_head_ - client.position, this->buf_ahead(client.position));
+        iov[1].iov_base = &this->buf_[0];
+        iov[1].iov_len = this->buf_head_ - (client.position + iov[0].iov_len);
+        client.position += client.socket->writev(iov, 2);
+        this->buf_tail_ = std::min(this->buf_tail_, client.position);
+    }
 }
 
-StreamServerComponent::Client::~Client() {
-    delete this->tcp_client;
+void StreamServerComponent::write() {
+    uint8_t buf[128];
+    ssize_t len;
+    for (Client &client : this->clients_) {
+        while ((len = client.socket->read(&buf, sizeof(buf))) > 0)
+            this->stream_->write_array(buf, len);
+
+        if (len == 0) {
+            ESP_LOGD(TAG, "Client %s disconnected", client.identifier.c_str());
+            client.disconnected = true;
+            continue;
+        }
+    }
 }
+
+StreamServerComponent::Client::Client(std::unique_ptr<esphome::socket::Socket> socket, std::string identifier, size_t position)
+    : socket(std::move(socket)), identifier{identifier}, position{position} {}

+ 44 - 29
components/stream_server/stream_server.h

@@ -1,4 +1,4 @@
-/* Copyright (C) 2020-2021 Oxan van Leeuwen
+/* Copyright (C) 2020-2023 Oxan van Leeuwen
  *
  * This program is free software: you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
@@ -16,39 +16,34 @@
 
 #pragma once
 
-#include "esphome/core/version.h"
 #include "esphome/core/component.h"
+#include "esphome/components/socket/socket.h"
 #include "esphome/components/uart/uart.h"
 
-// Provide VERSION_CODE for ESPHome versions lacking it, as existence checking doesn't work for function-like macros
-#ifndef VERSION_CODE
-#define VERSION_CODE(major, minor, patch) ((major) << 16 | (minor) << 8 | (patch))
+#ifdef USE_BINARY_SENSOR
+#include "esphome/components/binary_sensor/binary_sensor.h"
+#endif
+#ifdef USE_BINARY_SENSOR
+#include "esphome/components/sensor/sensor.h"
 #endif
 
 #include <memory>
 #include <string>
 #include <vector>
-#include <Stream.h>
-
-#ifdef ARDUINO_ARCH_ESP8266
-#include <ESPAsyncTCP.h>
-#else
-// AsyncTCP.h includes parts of freertos, which require FreeRTOS.h header to be included first
-#include <freertos/FreeRTOS.h>
-#include <AsyncTCP.h>
-#endif
-
-#if ESPHOME_VERSION_CODE >= VERSION_CODE(2021, 10, 0)
-using SSStream = esphome::uart::UARTComponent;
-#else
-using SSStream = Stream;
-#endif
 
 class StreamServerComponent : public esphome::Component {
 public:
     StreamServerComponent() = default;
-    explicit StreamServerComponent(SSStream *stream) : stream_{stream} {}
+    explicit StreamServerComponent(esphome::uart::UARTComponent *stream) : stream_{stream} {}
     void set_uart_parent(esphome::uart::UARTComponent *parent) { this->stream_ = parent; }
+    void set_buffer_size(size_t size) { this->buf_size_ = size; }
+
+#ifdef USE_BINARY_SENSOR
+    void set_connected_sensor(esphome::binary_sensor::BinarySensor *connected) { this->connected_sensor_ = connected; }
+#endif
+#ifdef USE_SENSOR
+    void set_connection_count_sensor(esphome::sensor::Sensor *connection_count) { this->connection_count_sensor_ = connection_count; }
+#endif
 
     void setup() override;
     void loop() override;
@@ -60,22 +55,42 @@ public:
     void set_port(uint16_t port) { this->port_ = port; }
 
 protected:
+    void publish_sensor();
+
+    void accept();
     void cleanup();
     void read();
+    void flush();
     void write();
 
+    size_t buf_index(size_t pos) { return pos & (this->buf_size_ - 1); }
+    /// Return the number of consecutive elements that are ahead of @p pos in memory.
+    size_t buf_ahead(size_t pos) { return (pos | (this->buf_size_ - 1)) - pos + 1; }
+
     struct Client {
-        Client(AsyncClient *client, std::vector<uint8_t> &recv_buf);
-        ~Client();
+        Client(std::unique_ptr<esphome::socket::Socket> socket, std::string identifier, size_t position);
 
-        AsyncClient *tcp_client{nullptr};
+        std::unique_ptr<esphome::socket::Socket> socket{nullptr};
         std::string identifier{};
         bool disconnected{false};
+        size_t position{0};
     };
 
-    SSStream *stream_{nullptr};
-    AsyncServer server_{0};
-    uint16_t port_{6638};
-    std::vector<uint8_t> recv_buf_{};
-    std::vector<std::unique_ptr<Client>> clients_{};
+    esphome::uart::UARTComponent *stream_{nullptr};
+    uint16_t port_;
+    size_t buf_size_;
+
+#ifdef USE_BINARY_SENSOR
+    esphome::binary_sensor::BinarySensor *connected_sensor_;
+#endif
+#ifdef USE_SENSOR
+    esphome::sensor::Sensor *connection_count_sensor_;
+#endif
+
+    std::unique_ptr<uint8_t[]> buf_{};
+    size_t buf_head_{0};
+    size_t buf_tail_{0};
+
+    std::unique_ptr<esphome::socket::Socket> socket_{};
+    std::vector<Client> clients_{};
 };