Program Listing for File static_registry.hpp

Return to documentation for file (include/rtest/static_registry.hpp)

// Copyright 2024 Beam Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// @file      static_registry.hpp
// @author    Sławomir Cielepak (slawomir.cielepak@gmail.com)
// @date      2024-11-26
//
// @brief     Mock header for ROS 2 static registry.

#pragma once

#include <algorithm>
#include <exception>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>

#include <boost/type_index.hpp>

#include <rtest/single_instance.hpp>
#include <rtest/registry_cleaner.hpp>

// Forward declarations of the rclcpp entity base classes. Within this header they are only
// used as std::weak_ptr template arguments, so their full definitions are not needed here.
namespace rclcpp
{
class PublisherBase;
class SubscriptionBase;
class TimerBase;
class ServiceBase;
class ClientBase;
}  // namespace rclcpp

// Forward declarations of the rclcpp_action server/client base classes. Within this header
// they are only used as std::weak_ptr template arguments.
namespace rclcpp_action
{
class ServerBase;
class ClientBase;
}  // namespace rclcpp_action

namespace rtest
{

/// @brief Empty common base type for mocks.
///
/// StaticMocksRegistry stores mocks as `std::weak_ptr<MockBase>` keyed by a raw pointer.
/// The class declares no members of its own.
///
/// NOTE(review): there is no virtual destructor, so derived mocks are presumably always
/// owned through a `std::shared_ptr` created for the derived type - confirm before
/// deleting through a `MockBase *`.
class MockBase
{
};

/// @brief Registry in which mocked ROS 2 entities are recorded so that they can be looked
///        up by fully qualified node name and topic / service / action name.
///
/// All entities are held as `std::weak_ptr`, so the registry never keeps an entity alive;
/// callers must `lock()` a returned handle and check the result.
///
/// Thread-safety: only the two lazy-init lists are protected (by `lazy_init_mutex_`). Every
/// other container is accessed without synchronization.
class StaticMocksRegistry : SingleInstance<StaticMocksRegistry>
{
public:
  /// @brief A deferred initialization request for an action client or server.
  struct LazyInitEntry
  {
    void * raw_ptr;                       ///< Identity key used by removeLazyInit*().
    std::string node_name;                ///< Stored with the entry; only used for logging here.
    std::string action_name;              ///< Stored with the entry; only used for logging here.
    std::function<void()> init_callback;  ///< Invoked by tryLazyInit(); may throw to be retried.
  };

  using TopicNameT = std::string;
  using FullyQualifiedNodeNameT = std::string;
  using TopicToPublishersMapT = std::map<TopicNameT, std::weak_ptr<rclcpp::PublisherBase>>;
  using TopicToSubscriptionsMapT = std::map<TopicNameT, std::weak_ptr<rclcpp::SubscriptionBase>>;
  using ServiceNameT = std::string;
  using ServiceToServicesMapT = std::map<ServiceNameT, std::weak_ptr<rclcpp::ServiceBase>>;
  using ServiceToClientsMapT = std::map<ServiceNameT, std::weak_ptr<rclcpp::ClientBase>>;
  using ActionNameT = std::string;
  using ActionToServersMapT = std::map<ActionNameT, std::weak_ptr<rclcpp_action::ServerBase>>;
  using ActionToClientsMapT = std::map<ActionNameT, std::weak_ptr<rclcpp_action::ClientBase>>;

  /// @brief Access the single registry object (`theRegistry_`, defined outside this header).
  static StaticMocksRegistry & instance() { return theRegistry_; }

  // ---------------------------------------------------------------------------------------
  // Publishers
  // ---------------------------------------------------------------------------------------

  /// @brief Record a publisher under (nodeName, topicName).
  /// @tparam MessageT Message type; only used for the verbose log line.
  /// An existing entry is kept while it is still alive and replaced once it has expired.
  template <typename MessageT>
  void registerPublisher(
    const FullyQualifiedNodeNameT & nodeName,
    const TopicNameT & topicName,
    std::weak_ptr<rclcpp::PublisherBase> pub)
  {
    logRegistration<MessageT>("registerPublisher", nodeName, topicName);
    registerEntity(publishersRegistry_[nodeName], topicName, std::move(pub));
  }

  /// @brief All publishers recorded for @p nodeName, ordered by topic name.
  /// @return Empty vector when the node is unknown. Does not modify the registry.
  std::vector<std::weak_ptr<rclcpp::PublisherBase>> getNodePublishers(
    const FullyQualifiedNodeNameT & nodeName)
  {
    return collectNodeEntities(publishersRegistry_, nodeName);
  }

  /// @brief Publisher recorded under (nodeName, topicName), or an empty weak_ptr if none.
  std::weak_ptr<rclcpp::PublisherBase> getPublisher(
    const FullyQualifiedNodeNameT & nodeName,
    const TopicNameT & topicName)
  {
    return findNodeEntity(publishersRegistry_, nodeName, topicName);
  }

  // ---------------------------------------------------------------------------------------
  // Subscriptions
  // ---------------------------------------------------------------------------------------

  /// @brief Record a subscription under (nodeName, topicName).
  /// @tparam MessageT Message type; only used for the verbose log line.
  /// An existing entry is kept while it is still alive and replaced once it has expired.
  template <typename MessageT>
  void registerSubscription(
    const FullyQualifiedNodeNameT & nodeName,
    const TopicNameT & topicName,
    std::weak_ptr<rclcpp::SubscriptionBase> sub)
  {
    logRegistration<MessageT>("registerSubscription", nodeName, topicName);
    registerEntity(subscriptionsRegistry_[nodeName], topicName, std::move(sub));
  }

  /// @brief All subscriptions recorded for @p nodeName, ordered by topic name.
  /// @return Empty vector when the node is unknown. Does not modify the registry.
  std::vector<std::weak_ptr<rclcpp::SubscriptionBase>> getNodeSubscriptions(
    const FullyQualifiedNodeNameT & nodeName)
  {
    return collectNodeEntities(subscriptionsRegistry_, nodeName);
  }

  /// @brief Subscription recorded under (nodeName, topicName), or an empty weak_ptr if none.
  std::weak_ptr<rclcpp::SubscriptionBase> getSubscription(
    const FullyQualifiedNodeNameT & nodeName,
    const TopicNameT & topicName)
  {
    return findNodeEntity(subscriptionsRegistry_, nodeName, topicName);
  }

  // ---------------------------------------------------------------------------------------
  // Timers
  // ---------------------------------------------------------------------------------------

  /// @brief Append a timer to the list kept for @p nodeName.
  /// @return Always `true`.
  bool registerTimer(
    const FullyQualifiedNodeNameT & nodeName,
    std::weak_ptr<rclcpp::TimerBase> timer)
  {
    timersRegistry_[nodeName].push_back(std::move(timer));
    return true;
  }

  /// @brief Timers recorded for @p nodeName in registration order (empty if node unknown).
  std::vector<std::weak_ptr<rclcpp::TimerBase>> getTimers(const FullyQualifiedNodeNameT & nodeName)
  {
    return findEntity(timersRegistry_, nodeName);
  }

  /// @brief Turn the registration log lines written to std::cout on or off.
  void enableVerboseLogs(bool on) { verbose_ = on; }

  // ---------------------------------------------------------------------------------------
  // Mock attachment
  // ---------------------------------------------------------------------------------------

  /// @brief Mock attached to @p ptr, or an empty weak_ptr when nothing is attached.
  std::weak_ptr<MockBase> getMock(void * ptr)
  {
    const auto it = mockRegistry_.find(ptr);
    if (it == mockRegistry_.end()) {
      return {};
    }
    return it->second;
  }

  /// @brief Associate @p mock with @p ptr, overwriting any previous association.
  void attachMock(void * ptr, std::weak_ptr<MockBase> mock) { mockRegistry_[ptr] = std::move(mock); }

  /// @brief Remove the association for @p ptr; no-op when none exists.
  void detachMock(void * ptr) { mockRegistry_.erase(ptr); }

  // ---------------------------------------------------------------------------------------
  // Services
  // ---------------------------------------------------------------------------------------

  /// @brief Record a service server under (nodeName, serviceName).
  /// @tparam ServiceT Service type; only used for the verbose log line.
  /// A single leading '/' is stripped from @p serviceName before it is used as the key;
  /// the log line prints the name as passed in.
  template <typename ServiceT>
  void registerService(
    const FullyQualifiedNodeNameT & nodeName,
    const ServiceNameT & serviceName,
    std::weak_ptr<rclcpp::ServiceBase> service)
  {
    logRegistration<ServiceT>("registerService", nodeName, serviceName);
    registerEntity(
      servicesRegistry_[nodeName], withoutLeadingSlash(serviceName), std::move(service));
  }

  /// @brief All service servers recorded for @p nodeName, ordered by stored service name.
  /// @return Empty vector when the node is unknown. Does not modify the registry.
  std::vector<std::weak_ptr<rclcpp::ServiceBase>> getNodeServices(
    const FullyQualifiedNodeNameT & nodeName)
  {
    return collectNodeEntities(servicesRegistry_, nodeName);
  }

  /// @brief Service server recorded under (nodeName, serviceName), or empty weak_ptr.
  /// @note The lookup uses @p serviceName verbatim, whereas registerService() stores the key
  ///       with one leading '/' removed - pass the name without the leading slash.
  std::weak_ptr<rclcpp::ServiceBase> getService(
    const FullyQualifiedNodeNameT & nodeName,
    const ServiceNameT & serviceName)
  {
    return findNodeEntity(servicesRegistry_, nodeName, serviceName);
  }

  /// @brief Record a service client under (nodeName, serviceName).
  /// @tparam ServiceT Service type; only used for the verbose log line.
  /// A single leading '/' is stripped from @p serviceName before it is used as the key;
  /// the log line prints the name as passed in.
  template <typename ServiceT>
  void registerServiceClient(
    const FullyQualifiedNodeNameT & nodeName,
    const ServiceNameT & serviceName,
    std::weak_ptr<rclcpp::ClientBase> client)
  {
    logRegistration<ServiceT>("registerServiceClient", nodeName, serviceName);
    registerEntity(
      serviceClientsRegistry_[nodeName], withoutLeadingSlash(serviceName), std::move(client));
  }

  /// @brief All service clients recorded for @p nodeName, ordered by stored service name.
  /// @return Empty vector when the node is unknown. Does not modify the registry.
  std::vector<std::weak_ptr<rclcpp::ClientBase>> getNodeServiceClients(
    const FullyQualifiedNodeNameT & nodeName)
  {
    return collectNodeEntities(serviceClientsRegistry_, nodeName);
  }

  /// @brief Service client recorded under (nodeName, serviceName), or empty weak_ptr.
  /// @note The lookup uses @p serviceName verbatim, whereas registerServiceClient() stores
  ///       the key with one leading '/' removed - pass the name without the leading slash.
  std::weak_ptr<rclcpp::ClientBase> getServiceClient(
    const FullyQualifiedNodeNameT & nodeName,
    const ServiceNameT & serviceName)
  {
    return findNodeEntity(serviceClientsRegistry_, nodeName, serviceName);
  }

  // ---------------------------------------------------------------------------------------
  // Actions
  // ---------------------------------------------------------------------------------------

  /// @brief Record an action server under (nodeName, actionName); the name is used verbatim.
  /// @tparam ActionT Action type; only used for the verbose log line.
  template <typename ActionT>
  void registerActionServer(
    const FullyQualifiedNodeNameT & nodeName,
    const ActionNameT & actionName,
    std::weak_ptr<rclcpp_action::ServerBase> server)
  {
    logRegistration<ActionT>("registerActionServer", nodeName, actionName);
    registerEntity(actionServersRegistry_[nodeName], actionName, std::move(server));
  }

  /// @brief Record an action client under (nodeName, actionName); the name is used verbatim.
  /// @tparam ActionT Action type; only used for the verbose log line.
  template <typename ActionT>
  void registerActionClient(
    const FullyQualifiedNodeNameT & nodeName,
    const ActionNameT & actionName,
    std::weak_ptr<rclcpp_action::ClientBase> client)
  {
    logRegistration<ActionT>("registerActionClient", nodeName, actionName);
    registerEntity(actionClientsRegistry_[nodeName], actionName, std::move(client));
  }

  /// @brief Run pending server lazy-init callbacks, then look up the action server.
  /// @return The recorded server, or an empty weak_ptr when none is registered.
  std::weak_ptr<rclcpp_action::ServerBase> getActionServer(
    const FullyQualifiedNodeNameT & nodeName,
    const ActionNameT & actionName)
  {
    tryLazyInit(lazy_init_action_servers_);
    return findNodeEntity(actionServersRegistry_, nodeName, actionName);
  }

  /// @brief Invoke every pending callback in @p lazyInitVector.
  ///
  /// An entry is removed once its callback returns normally. If the callback throws a
  /// `std::exception` the entry is kept so that a later call retries it; any other exception
  /// type propagates to the caller.
  ///
  /// `lazy_init_mutex_` (a non-recursive `std::mutex`) is held while the callbacks run, so a
  /// callback must not call any member that locks it (register/removeLazyInit*, reset,
  /// getActionServer/getActionClient).
  void tryLazyInit(std::vector<LazyInitEntry> & lazyInitVector)
  {
    std::lock_guard<std::mutex> lock(lazy_init_mutex_);

    for (auto it = lazyInitVector.begin(); it != lazyInitVector.end();) {
      try {
        it->init_callback();
        it = lazyInitVector.erase(it);
      } catch (const std::exception &) {
        // Initialization not possible yet: keep the entry for the next attempt.
        ++it;
      }
    }
  }

  /// @brief Run pending client lazy-init callbacks, then look up the action client.
  /// @return The recorded client, or an empty weak_ptr when none is registered.
  std::weak_ptr<rclcpp_action::ClientBase> getActionClient(
    const FullyQualifiedNodeNameT & nodeName,
    const ActionNameT & actionName)
  {
    tryLazyInit(lazy_init_action_clients_);
    return findNodeEntity(actionClientsRegistry_, nodeName, actionName);
  }

  /// @brief Queue @p callback to be run by the next getActionClient() call.
  /// @param raw_ptr Key with which the entry can later be removed via removeLazyInitClient().
  void registerLazyInitClient(
    void * raw_ptr,
    const std::string & node_name,
    const std::string & action_name,
    std::function<void()> callback)
  {
    addLazyInitEntry(
      lazy_init_action_clients_, "registerLazyInitClient", raw_ptr, node_name, action_name,
      std::move(callback));
  }

  /// @brief Drop every pending client lazy-init entry whose key equals @p raw_ptr.
  void removeLazyInitClient(void * raw_ptr)
  {
    removeLazyInitEntries(lazy_init_action_clients_, raw_ptr);
  }

  /// @brief Queue @p callback to be run by the next getActionServer() call.
  /// @param raw_ptr Key with which the entry can later be removed via removeLazyInitServer().
  void registerLazyInitServer(
    void * raw_ptr,
    const std::string & node_name,
    const std::string & action_name,
    std::function<void()> callback)
  {
    addLazyInitEntry(
      lazy_init_action_servers_, "registerLazyInitServer", raw_ptr, node_name, action_name,
      std::move(callback));
  }

  /// @brief Drop every pending server lazy-init entry whose key equals @p raw_ptr.
  void removeLazyInitServer(void * raw_ptr)
  {
    removeLazyInitEntries(lazy_init_action_servers_, raw_ptr);
  }

  /// @brief Forget every registered entity, attached mock and pending lazy-init entry.
  /// The verbose flag is left unchanged.
  void reset()
  {
    publishersRegistry_.clear();
    subscriptionsRegistry_.clear();
    timersRegistry_.clear();
    servicesRegistry_.clear();
    serviceClientsRegistry_.clear();
    actionServersRegistry_.clear();
    actionClientsRegistry_.clear();
    mockRegistry_.clear();

    // The lazy-init lists are guarded by lazy_init_mutex_ everywhere else, so take it here
    // too. The entries are swapped out and destroyed only after the lock is released, so
    // that destroying a stored callback can never run while the mutex is held.
    std::vector<LazyInitEntry> droppedClients;
    std::vector<LazyInitEntry> droppedServers;
    {
      std::lock_guard<std::mutex> lock(lazy_init_mutex_);
      droppedClients.swap(lazy_init_action_clients_);
      droppedServers.swap(lazy_init_action_servers_);
    }
  }

private:
  /// Registers a MockRegistryCleaner with GoogleTest's event listeners. The listener is
  /// handed over as a raw `new` pointer because `TestEventListeners::Append` takes ownership.
  StaticMocksRegistry()
  {
    ::testing::UnitTest::GetInstance()->listeners().Append(new MockRegistryCleaner());
  }

  static StaticMocksRegistry theRegistry_;

  /// Name with one leading '/' removed (unchanged when it does not start with '/').
  static std::string withoutLeadingSlash(const std::string & name)
  {
    if (!name.empty() && name.front() == '/') {
      return name.substr(1);
    }
    return name;
  }

  /// Print the "StaticMocksRegistry::<fn><<type>>("<node>", "<name>")" line when verbose.
  template <typename T>
  void logRegistration(
    const char * functionName,
    const std::string & nodeName,
    const std::string & entityName) const
  {
    if (verbose_) {
      std::cout << "StaticMocksRegistry::" << functionName << "<"
                << boost::typeindex::type_id<T>().pretty_name() << ">(\"" << nodeName
                << "\", \"" << entityName << "\")\n";
    }
  }

  /// Store @p e under @p topicName unless a still-alive entity is already stored there.
  template <typename RegistryT, typename EntityT>
  static void registerEntity(RegistryT & reg, const TopicNameT & topicName, EntityT e)
  {
    auto & slot = reg[topicName];  // single lookup; creates an empty slot on first use
    if (slot.expired()) {
      slot = std::move(e);
    }
  }

  /// Copy of the value stored under @p topicName, or a value-initialized one if absent.
  template <typename RegistryT>
  static typename RegistryT::mapped_type findEntity(
    const RegistryT & reg,
    const TopicNameT & topicName)
  {
    const auto it = reg.find(topicName);
    if (it == reg.end()) {
      return {};
    }
    return it->second;
  }

  /// Two-level lookup (node name, then entity name) that never inserts into @p nodeRegistry.
  template <typename NodeRegistryT>
  static typename NodeRegistryT::mapped_type::mapped_type findNodeEntity(
    const NodeRegistryT & nodeRegistry,
    const FullyQualifiedNodeNameT & nodeName,
    const std::string & entityName)
  {
    const auto nodeIt = nodeRegistry.find(nodeName);
    if (nodeIt == nodeRegistry.end()) {
      return {};
    }
    return findEntity(nodeIt->second, entityName);
  }

  /// All entities stored for @p nodeName (map order), without inserting into @p nodeRegistry.
  template <typename NodeRegistryT>
  static std::vector<typename NodeRegistryT::mapped_type::mapped_type> collectNodeEntities(
    const NodeRegistryT & nodeRegistry,
    const FullyQualifiedNodeNameT & nodeName)
  {
    std::vector<typename NodeRegistryT::mapped_type::mapped_type> entities{};
    const auto nodeIt = nodeRegistry.find(nodeName);
    if (nodeIt == nodeRegistry.end()) {
      return entities;
    }
    entities.reserve(nodeIt->second.size());
    for (const auto & entry : nodeIt->second) {
      entities.push_back(entry.second);
    }
    return entities;
  }

  /// Append a lazy-init entry to @p pending under the mutex and optionally log it.
  void addLazyInitEntry(
    std::vector<LazyInitEntry> & pending,
    const char * functionName,
    void * raw_ptr,
    const std::string & node_name,
    const std::string & action_name,
    std::function<void()> callback)
  {
    std::lock_guard<std::mutex> lock(lazy_init_mutex_);
    pending.push_back({raw_ptr, node_name, action_name, std::move(callback)});

    if (verbose_) {
      std::cout << "StaticMocksRegistry::" << functionName << " - " << "Node: '" << node_name
                << "', Action: '" << action_name << "'" << std::endl;
    }
  }

  /// Erase, under the mutex, every entry of @p pending whose key equals @p raw_ptr.
  void removeLazyInitEntries(std::vector<LazyInitEntry> & pending, void * raw_ptr)
  {
    std::lock_guard<std::mutex> lock(lazy_init_mutex_);
    pending.erase(
      std::remove_if(
        pending.begin(),
        pending.end(),
        [raw_ptr](const LazyInitEntry & entry) { return entry.raw_ptr == raw_ptr; }),
      pending.end());
  }

  // Per-node registries: fully qualified node name -> (entity name -> weak handle).
  std::map<FullyQualifiedNodeNameT, TopicToPublishersMapT> publishersRegistry_;
  std::map<FullyQualifiedNodeNameT, TopicToSubscriptionsMapT> subscriptionsRegistry_;
  std::map<FullyQualifiedNodeNameT, std::vector<std::weak_ptr<rclcpp::TimerBase>>> timersRegistry_;
  std::map<FullyQualifiedNodeNameT, ServiceToServicesMapT> servicesRegistry_;
  std::map<FullyQualifiedNodeNameT, ServiceToClientsMapT> serviceClientsRegistry_;
  std::map<FullyQualifiedNodeNameT, ActionToServersMapT> actionServersRegistry_;
  std::map<FullyQualifiedNodeNameT, ActionToClientsMapT> actionClientsRegistry_;

  // Raw pointer -> attached mock.
  std::map<void *, std::weak_ptr<MockBase>> mockRegistry_;

  // Pending lazy initializations; both lists are guarded by lazy_init_mutex_.
  std::vector<LazyInitEntry> lazy_init_action_clients_;
  std::vector<LazyInitEntry> lazy_init_action_servers_;
  std::mutex lazy_init_mutex_;

  bool verbose_{false};
};

/// @brief Free-function counterpart of the registry's verbose-log switch.
///
/// Only declared here; the definition is not part of this header.
/// NOTE(review): presumably forwards to StaticMocksRegistry::instance().enableVerboseLogs(on)
/// - confirm against the definition.
void enableVerboseLogs(bool on);

}  // namespace rtest