#pragma once

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <span>

#include "lock.h"

/// Fixed-capacity byte ring buffer laid over caller-provided storage.
///
/// `buffer` is a non-owning view; the storage it refers to must outlive the
/// RingBuffer. `read_ptr` and `write_ptr` are byte indices into `buffer`
/// (not pointers). Equal indices are ambiguous on their own, so `full`
/// distinguishes a completely full buffer from an empty one.
///
/// Every method constructs an InterruptLock (declared in lock.h) for its
/// duration, and methods call one another while already holding one, so this
/// code relies on InterruptLock being safe to nest.
/// NOTE(review): presumably InterruptLock masks interrupts for its lifetime —
/// confirm against lock.h.
struct RingBuffer {
    std::span<std::byte> buffer;

    std::atomic<size_t> read_ptr{0};
    std::atomic<size_t> write_ptr{0};
    std::atomic<bool> full{false};

    /// Copies all of `data` into the buffer and advances the write index.
    ///
    /// The operation is all-or-nothing.
    /// @return false (buffer untouched) if `data` does not fit in the free
    ///         space, true otherwise. Storing an empty span succeeds and
    ///         changes nothing.
    bool Store(std::span<const std::byte> data) {
        InterruptLock lock;

        if (data.size() > FreeSpace()) {
            return false;
        }
        const size_t start = write_ptr;
        // First part: from the write index up to the end of the storage.
        const size_t head = std::min(buffer.size() - start, data.size());
        std::copy(data.begin(), data.begin() + head, buffer.begin() + start);
        if (head < data.size()) {
            // Remainder wraps around to the start of the storage.
            std::copy(data.begin() + head, data.end(), buffer.begin());
        }
        Push(data.size());

        return true;
    }

    /// Copies exactly `out.size()` bytes out of the buffer and advances the
    /// read index.
    ///
    /// The operation is all-or-nothing.
    /// @return false (nothing copied or consumed) if fewer than `out.size()`
    ///         bytes are stored, true otherwise.
    bool Load(std::span<std::byte> out) {
        InterruptLock lock;

        if (out.size() > AvailableData()) {
            return false;
        }
        const size_t start = read_ptr;
        // First part: from the read index up to the end of the storage.
        const size_t head = std::min(buffer.size() - start, out.size());
        std::copy(buffer.begin() + start, buffer.begin() + start + head,
                  out.begin());
        if (head < out.size()) {
            // Remainder wraps around to the start of the storage.
            std::copy(buffer.begin(), buffer.begin() + (out.size() - head),
                      out.begin() + head);
        }
        Pop(out.size());
        return true;
    }

    /// Advances the write index by `amount` bytes without copying anything,
    /// marking that many bytes as stored.
    ///
    /// @return false (no change) if `amount` exceeds the free space,
    ///         true otherwise.
    bool Push(size_t amount) {
        InterruptLock lock;

        if (amount > FreeSpace()) {
            return false;
        }
        if (amount == 0) {
            // Nothing was added: an empty buffer must not be flagged as full,
            // and this also avoids `% 0` when the storage is zero-length.
            return true;
        }
        const size_t next = (write_ptr + amount) % buffer.size();
        write_ptr = next;
        if (read_ptr == next) {
            // The writer caught up with the reader after adding data.
            full = true;
        }
        return true;
    }

    /// Advances the read index by `amount` bytes without copying anything,
    /// discarding that many stored bytes.
    ///
    /// @return false (no change) if `amount` exceeds the stored data,
    ///         true otherwise.
    bool Pop(size_t amount) {
        InterruptLock lock;

        if (amount > AvailableData()) {
            return false;
        }
        if (amount == 0) {
            // No-op; also avoids `% 0` when the storage is zero-length.
            return true;
        }
        read_ptr = (read_ptr + amount) % buffer.size();
        // Removing at least one byte means the buffer can no longer be full.
        full = false;
        return true;
    }

    /// @return Number of bytes that can still be stored.
    size_t FreeSpace() const {
        InterruptLock lock;

        return buffer.size() - AvailableData();
    }

    /// @return Number of bytes currently stored and readable.
    size_t AvailableData() const {
        InterruptLock lock;

        const size_t rd = read_ptr;
        const size_t wr = write_ptr;
        if (rd == wr) {
            // Equal indices mean either completely full or completely empty.
            return full ? buffer.size() : 0;
        }
        return (buffer.size() + wr - rd) % buffer.size();
    }

    /// @return Pointer to the byte at the current read index inside the
    ///         underlying storage. Can be paired with
    ///         ContiguousAvailableData() and Pop() to consume data in place.
    uint8_t* RawReadPointer() const {
        InterruptLock lock;

        return reinterpret_cast<uint8_t*>(buffer.data() + read_ptr);
    }

    /// @return Number of stored bytes readable in one contiguous run starting
    ///         at the read index, i.e. without wrapping to index 0. Zero when
    ///         the buffer is empty.
    size_t ContiguousAvailableData() const {
        InterruptLock lock;

        const size_t rd = read_ptr;
        const size_t wr = write_ptr;
        if (rd < wr) {
            // Data does not wrap: everything stored is contiguous.
            return wr - rd;
        }
        if (rd == wr && !full) {
            // Empty buffer: nothing to read.
            return 0;
        }
        // Data wraps (or the buffer is full): the contiguous run extends from
        // the read index to the end of the storage.
        return buffer.size() - rd;
    }
};