diff --git a/include/ranges/enumerate.hpp b/include/ranges/enumerate.hpp index 2f9f478..4cd558b 100644 --- a/include/ranges/enumerate.hpp +++ b/include/ranges/enumerate.hpp @@ -1,66 +1,101 @@ #pragma once #include "types.hpp" +#include #include - +#include namespace util { namespace rg = std::ranges; -template struct enumerate_iterator_t : rg::iterator_t { +template struct enumerate_iterator_t { using base = rg::iterator_t; using enumerator_type = std::size_t; using value_type = std::pair>; + using difference_type = std::ptrdiff_t; using reference = std::pair>>; + base base_iter{}; enumerator_type enumerator{}; enumerate_iterator_t() = default; - enumerate_iterator_t(const base &b) : base{b}, enumerator(0) {} + enumerate_iterator_t(const base &b) : base_iter{b}, enumerator(0) {} + enumerate_iterator_t(base &&b) : base_iter{std::move(b)}, enumerator(0) {} - auto operator<=>(const enumerate_iterator_t &other) const { - return static_cast(*this) <=> - static_cast(other); + auto constexpr operator==(const enumerate_iterator_t &other) const -> bool { + return base_iter == other.base_iter; + } + + template S> + auto constexpr operator==(const S &sentinel) const -> bool { + return sentinel == base_iter; + } + + auto constexpr operator<=>(const enumerate_iterator_t &other) const { + return base_iter <=> other.base_iter; } auto inc_enumerator() -> void { enumerator++; } + auto dec_enumerator() -> void { enumerator--; } + auto dec_enumerator(const difference_type &n) -> void { enumerator -= n; } auto operator++(int) -> enumerate_iterator_t { - const auto result = static_cast(*this)++; + const auto result = *this; + base_iter++; inc_enumerator(); return result; } auto operator++() -> enumerate_iterator_t & { inc_enumerator(); - ++static_cast(*this); + ++base_iter; return (*this); } - auto operator*() -> reference { - return reference(enumerator, *static_cast(*this)); + auto operator--(int) -> enumerate_iterator_t + requires(std::bidirectional_iterator) { + const auto result = *this; + base_iter--; + dec_enumerator(); + return result; } + + auto operator--() + -> enumerate_iterator_t &requires(std::bidirectional_iterator) { + dec_enumerator(); + --base_iter; + return (*this); + } + + auto constexpr operator+(const difference_type &n) const + -> enumerate_iterator_t { + auto ret = *this; + ret.inc_enumerator(n); + base_iter += n; + + return ret; + } + + auto operator*() -> reference { return reference(enumerator, *base_iter); } auto operator*() const -> reference { - return reference(enumerator, *static_cast(*this)); + return reference(enumerator, *base_iter); } }; template class enumerate_view : public rg::view_interface> { - R base_{}; isize enumerator{}; - enumerate_iterator_t iter_{std::begin(base_)}; + enumerate_iterator_t iter_; + enumerate_iterator_t end_; public: enumerate_view() = default; - constexpr enumerate_view(R base) : base_(base), iter_(std::begin(base_)) {} - - constexpr R base() const & { return base_; } - constexpr R base() && { return std::move(base_); } + constexpr enumerate_view(R base) + : iter_(std::begin(base)), end_(std::end(base)) {} constexpr auto begin() const { return iter_; } - constexpr auto end() const { return std::end(base_); } + constexpr auto end() const { return end_; } }; namespace detail {