Program Listing for File create_timer_mock.hpp

Return to documentation for file (include/rtest/create_timer_mock.hpp)

// Copyright 2024 Beam Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// @file      create_timer_mock.hpp
// @author    Sławomir Cielepak (slawomir.cielepak@gmail.com)
// @date      2024-11-26
//
// @brief     Mock header for timer creation functionality.

#pragma once

#include <gmock/gmock.h>
#include <rtest/static_registry.hpp>

#include <chrono>
#include <exception>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "rclcpp/duration.hpp"
#include "rclcpp/node_interfaces/get_node_base_interface.hpp"
#include "rclcpp/node_interfaces/get_node_clock_interface.hpp"
#include "rclcpp/node_interfaces/get_node_timers_interface.hpp"
#include "rclcpp/node_interfaces/node_base_interface.hpp"
#include "rclcpp/node_interfaces/node_clock_interface.hpp"
#include "rclcpp/node_interfaces/node_timers_interface.hpp"

namespace rclcpp
{
namespace detail
{

/// @brief Convert an arbitrary std::chrono duration into a nanosecond timer period.
///
/// @tparam DurationRepT Arithmetic representation of the input duration.
/// @tparam DurationT    Tick period (std::ratio) of the input duration.
/// @param period        Requested timer period.
/// @return @p period expressed as std::chrono::nanoseconds.
/// @throws std::invalid_argument if @p period is negative, or if it exceeds
///         std::chrono::nanoseconds::max() minus one tick of the input duration.
/// @throws std::runtime_error if the final conversion still yields a negative value.
template <typename DurationRepT, typename DurationT>
std::chrono::nanoseconds safe_cast_to_period_in_ns(
  std::chrono::duration<DurationRepT, DurationT> period)
{
  using InputDuration = std::chrono::duration<DurationRepT, DurationT>;
  using DoubleNanoseconds = std::chrono::duration<double, std::chrono::nanoseconds::period>;

  if (InputDuration::zero() > period) {
    throw std::invalid_argument{"timer period cannot be negative"};
  }

  // The upper bound sits one input tick below nanoseconds::max() and is held in a
  // double-based duration, so comparing `period` against it never needs an integer
  // nanosecond conversion of `period`.
  // NOTE(review): the one-tick margin presumably absorbs rounding from the double
  // conversion - confirm against upstream rclcpp.
  constexpr auto upper_bound = std::chrono::duration_cast<DoubleNanoseconds>(
    std::chrono::nanoseconds::max() - InputDuration(1));
  if (upper_bound < period) {
    throw std::invalid_argument{"timer period must be less than std::chrono::nanoseconds::max()"};
  }

  // Defensive check: a negative result after the cast indicates wrap-around.
  const std::chrono::nanoseconds converted =
    std::chrono::duration_cast<std::chrono::nanoseconds>(period);
  if (std::chrono::nanoseconds::zero() > converted) {
    throw std::runtime_error{"Casting timer period to nanoseconds resulted in integer overflow."};
  }
  return converted;
}
}  // namespace detail

/// @brief Create a timer from shared node interfaces and an rclcpp::Duration period.
///
/// Thin adapter: converts @p period to std::chrono::nanoseconds and forwards to the
/// raw-pointer create_timer() overload. That overload validates its inputs (null
/// clock/interfaces are rejected) and registers the timer with rtest::StaticMocksRegistry.
///
/// @param node_base   Node base interface; its raw pointer is forwarded.
/// @param node_timers Node timers interface; its raw pointer is forwarded.
/// @param clock       Clock passed through to the timer.
/// @param period      Timer period.
/// @param callback    Timer callback, perfectly forwarded.
/// @param group       Callback group passed to add_timer() (may be nullptr).
/// @param autostart   Forwarded unchanged as the autostart flag.
/// @return The created timer as an rclcpp::TimerBase shared pointer.
template <typename CallbackT>
typename rclcpp::TimerBase::SharedPtr create_timer(
  std::shared_ptr<node_interfaces::NodeBaseInterface> node_base,
  std::shared_ptr<node_interfaces::NodeTimersInterface> node_timers,
  rclcpp::Clock::SharedPtr clock,
  rclcpp::Duration period,
  CallbackT && callback,
  rclcpp::CallbackGroup::SharedPtr group = nullptr,
  bool autostart = true)
{
  const auto period_ns = period.to_chrono<std::chrono::nanoseconds>();
  auto * const base = node_base.get();
  auto * const timers = node_timers.get();

  return create_timer(
    clock, period_ns, std::forward<CallbackT>(callback), group, base, timers, autostart);
}

/// @brief Create a timer for any node-like object and an rclcpp::Duration period.
///
/// Extracts the base and timers interfaces from @p node via rclcpp's
/// get_node_*_interface() helpers. It then forwards to the raw-pointer create_timer()
/// overload, which validates inputs and registers the timer with rtest::StaticMocksRegistry.
///
/// @param node      Node (or node-like object) accepted by the get_node_*_interface() helpers.
/// @param clock     Clock passed through to the timer.
/// @param period    Timer period.
/// @param callback  Timer callback, perfectly forwarded.
/// @param group     Callback group passed to add_timer() (may be nullptr).
/// @param autostart Forwarded unchanged as the autostart flag.
/// @return The created timer as an rclcpp::TimerBase shared pointer.
template <typename NodeT, typename CallbackT>
typename rclcpp::TimerBase::SharedPtr create_timer(
  NodeT node,
  rclcpp::Clock::SharedPtr clock,
  rclcpp::Duration period,
  CallbackT && callback,
  rclcpp::CallbackGroup::SharedPtr group = nullptr,
  bool autostart = true)
{
  // Keep the interface handles alive while their raw pointers are in use.
  auto base_interface = rclcpp::node_interfaces::get_node_base_interface(node);
  auto timers_interface = rclcpp::node_interfaces::get_node_timers_interface(node);

  return create_timer(
    clock,
    period.to_chrono<std::chrono::nanoseconds>(),
    std::forward<CallbackT>(callback),
    group,
    base_interface.get(),
    timers_interface.get(),
    autostart);
}

/// @brief Mock of rclcpp's core create_timer(): builds a GenericTimer, adds it to the
///        node, and records it in rtest::StaticMocksRegistry.
///
/// @param clock       Clock driving the timer; must not be null.
/// @param period      Timer period; validated by detail::safe_cast_to_period_in_ns().
/// @param callback    Timer callback (taken by value and moved into the timer).
/// @param group       Callback group passed to NodeTimersInterface::add_timer().
/// @param node_base   Node base interface; supplies the context and the registry key.
/// @param node_timers Node timers interface the timer is added to.
/// @param autostart   Forwarded unchanged to the timer constructor.
/// @return The newly created timer.
/// @throws std::invalid_argument for a null clock/node_base/node_timers or an invalid period.
/// @throws std::runtime_error if the period conversion overflows.
template <typename DurationRepT, typename DurationT, typename CallbackT>
typename rclcpp::GenericTimer<CallbackT>::SharedPtr create_timer(
  rclcpp::Clock::SharedPtr clock,
  std::chrono::duration<DurationRepT, DurationT> period,
  CallbackT callback,
  rclcpp::CallbackGroup::SharedPtr group,
  node_interfaces::NodeBaseInterface * node_base,
  node_interfaces::NodeTimersInterface * node_timers,
  bool autostart = true)
{
  // Guard clauses, checked in this order so the first missing argument is reported.
  if (!clock) {
    throw std::invalid_argument{"clock cannot be null"};
  }
  if (!node_base) {
    throw std::invalid_argument{"input node_base cannot be null"};
  }
  if (!node_timers) {
    throw std::invalid_argument{"input node_timers cannot be null"};
  }

  // Validate/convert the period before anything is constructed.
  const auto period_ns = detail::safe_cast_to_period_in_ns(period);

  using TimerT = rclcpp::GenericTimer<CallbackT>;
  auto timer = TimerT::make_shared(
    std::move(clock), period_ns, std::move(callback), node_base->get_context(), autostart);
  node_timers->add_timer(timer, group);

  // Make the timer discoverable by rtest::findTimers(), which queries the same
  // registry by the node's fully qualified name.
  rtest::StaticMocksRegistry::instance().registerTimer(
    node_base->get_fully_qualified_name(), timer);

  return timer;
}

/// @brief Mock of rclcpp's create_wall_timer(): builds an rclcpp::WallTimer, adds it to
///        the node, and records it in rtest::StaticMocksRegistry.
///
/// @param period      Timer period; validated by detail::safe_cast_to_period_in_ns().
/// @param callback    Timer callback (taken by value and moved into the timer).
/// @param group       Callback group passed to NodeTimersInterface::add_timer().
/// @param node_base   Node base interface; supplies the context and the registry key.
/// @param node_timers Node timers interface the timer is added to.
/// @param autostart   Forwarded unchanged to the timer constructor.
/// @return The newly created wall timer.
/// @throws std::invalid_argument for a null node_base/node_timers or an invalid period.
/// @throws std::runtime_error if the period conversion overflows.
template <typename DurationRepT, typename DurationT, typename CallbackT>
typename rclcpp::WallTimer<CallbackT>::SharedPtr create_wall_timer(
  std::chrono::duration<DurationRepT, DurationT> period,
  CallbackT callback,
  rclcpp::CallbackGroup::SharedPtr group,
  node_interfaces::NodeBaseInterface * node_base,
  node_interfaces::NodeTimersInterface * node_timers,
  bool autostart = true)
{
  // Both interfaces are dereferenced below, so reject null up front.
  if (!node_base) {
    throw std::invalid_argument{"input node_base cannot be null"};
  }
  if (!node_timers) {
    throw std::invalid_argument{"input node_timers cannot be null"};
  }

  // Validate/convert the period before anything is constructed.
  const auto period_ns = detail::safe_cast_to_period_in_ns(period);

  using TimerT = rclcpp::WallTimer<CallbackT>;
  auto timer =
    TimerT::make_shared(period_ns, std::move(callback), node_base->get_context(), autostart);
  node_timers->add_timer(timer, group);

  // Make the timer discoverable by rtest::findTimers(), which queries the same
  // registry by the node's fully qualified name.
  rtest::StaticMocksRegistry::instance().registerTimer(
    node_base->get_fully_qualified_name(), timer);

  return timer;
}
}  // namespace rclcpp

namespace rtest
{

/// @brief Collect the still-alive timers that were registered for a node.
///
/// The create_timer()/create_wall_timer() mocks record every timer they create in
/// StaticMocksRegistry, keyed by the creating node's fully qualified name. The
/// registry holds weak references, so timers that have already been destroyed are
/// skipped here.
///
/// Declared `inline` (not `static inline`) so that every translation unit including
/// this header shares one definition instead of getting a private copy.
///
/// @param fullyQualifiedNodeName Fully qualified name of the node whose timers are wanted.
/// @return Owning pointers to every registered timer that is still alive, in registry order.
inline std::vector<std::shared_ptr<rclcpp::TimerBase>> findTimers(
  const std::string & fullyQualifiedNodeName)
{
  std::vector<std::shared_ptr<rclcpp::TimerBase>> timers{};
  for (const auto & weakPtr : StaticMocksRegistry::instance().getTimers(fullyQualifiedNodeName)) {
    if (auto timer = weakPtr.lock()) {
      // Move the locked pointer in to avoid an extra atomic refcount increment.
      timers.push_back(std::move(timer));
    }
  }
  return timers;
}

/// @brief Convenience overload: look up timers by a node's fully qualified name.
///
/// The node pointer is taken by const reference; a by-value copy would only add an
/// atomic refcount round-trip.
///
/// @param nodePtr Node whose timers are wanted; must not be null.
/// @return Same as findTimers(const std::string &) for nodePtr->get_fully_qualified_name().
/// @throws std::invalid_argument if @p nodePtr is null (previously undefined behavior).
template <typename NodeT>
std::vector<std::shared_ptr<rclcpp::TimerBase>> findTimers(const std::shared_ptr<NodeT> & nodePtr)
{
  if (!nodePtr) {
    throw std::invalid_argument{"findTimers: nodePtr cannot be null"};
  }
  return findTimers(nodePtr->get_fully_qualified_name());
}

}  // namespace rtest