#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
#  Copyright (c) 2020 Binero AB https://binero.com
#  Copyright (c) 2013 Catalyst IT http://www.catalyst.net.nz
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import argparse
import os
import re
import subprocess
import sys
import socket
import json


__version__ = '1.1.0'

# Default executable locations; overridable with -e/--exe and -A/--admexe.
CEPH_COMMAND, CEPH_ADM_COMMAND = '/usr/bin/ceph', '/usr/sbin/cephadm'

# Plugin exit codes (Nagios convention: 0 OK, 2 CRITICAL, 3 UNKNOWN).
STATUS_OK, STATUS_CRITICAL, STATUS_UNKNOWN = 0, 2, 3


def get_host_ips(hostname, conf_file=None):
    """
    Get all IP addresses for a hostname from DNS and ceph config.

    Both lookups are best-effort: a resolution failure or an unreadable
    config file simply contributes no addresses.

    :param hostname: host name (or IP address) to look up.
    :param conf_file: optional ceph.conf path. If it contains a
        ``[mon.<hostname>]`` section, the IPv4 addresses found in that
        section's first ``public_addr`` / ``public_addrv`` option are added.
        Spaces or dashes in place of underscores (``public addr``) are
        accepted, as Ceph allows in config files.
    :returns: set of IP address strings (possibly empty; DNS results may
        include IPv6 addresses).
    """
    ips = set()

    # Get IPs from DNS resolution.
    try:
        addrinfo = socket.getaddrinfo(hostname, None, 0, socket.SOCK_STREAM)
    except (OSError, ValueError):
        # socket.gaierror is an OSError; ValueError (incl. UnicodeError)
        # covers host names that cannot be encoded for the lookup.
        addrinfo = []
    for info in addrinfo:
        # The sockaddr tuple is the last element; its first item is the IP.
        ips.add(info[-1][0])

    if not conf_file or not os.path.exists(conf_file):
        return ips

    # Get IP from ceph config [mon.<hostname>] public_addr
    section_header = '[mon.' + hostname + ']'
    ip_pattern = r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b'
    try:
        with open(conf_file, 'r') as f:
            in_mon_section = False
            for line in f:
                stripped = line.strip()
                # ceph.conf accepts both '#' and ';' comment lines.
                if not stripped or stripped.startswith(('#', ';')):
                    continue
                if stripped == section_header:
                    in_mon_section = True
                    continue
                if not in_mon_section:
                    continue
                # Any other section header ends the mon section.
                if stripped.startswith('['):
                    break
                key, sep, value = stripped.partition('=')
                # Normalise 'public addr' / 'public-addr' to 'public_addr'.
                option = '_'.join(key.replace('-', ' ').split())
                if not sep or option not in ('public_addr', 'public_addrv'):
                    continue
                for match in re.findall(ip_pattern, value):
                    if all(0 <= int(p) <= 255 for p in match.split('.')):
                        ips.add(match)
                break
    except (OSError, ValueError):
        # Best-effort: an unreadable or undecodable config adds nothing.
        pass

    return ips


def extract_ips_from_osd_line(line):
    """
    Collect every IPv4 address that appears anywhere in an OSD dump line.

    Dotted quads containing an octet above 255 are discarded.

    :param line: one line of 'ceph osd dump' output (any string works).
    :returns: set of address strings.
    """
    def _is_valid(candidate):
        # Octets are digit-only, so only the upper bound needs checking.
        return max(int(octet) for octet in candidate.split('.')) <= 255

    found = re.findall(r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b', line)
    return {candidate for candidate in found if _is_valid(candidate)}


def find_osds_for_host(osd_dump, hostname, conf_file=None):
    """
    Find the OSDs in a 'ceph osd dump' listing that belong to a host.

    A line is attributed to the host when any IPv4 address on it is one of
    the host's addresses (from get_host_ips), or when the hostname itself
    appears on the line as a whole token. Only OSDs whose two state
    columns read exactly 'up' and 'in' are returned.

    :param osd_dump: text output of 'ceph osd dump'.
    :param hostname: host name or IP address to match.
    :param conf_file: optional ceph.conf path forwarded to get_host_ips.
    :returns: list of OSD names (e.g. ['osd.0', 'osd.1']) in dump order.
    """
    host_ips = get_host_ips(hostname, conf_file)

    if not host_ips:
        # If we couldn't resolve the hostname, fall back to old behavior:
        # try to use the hostname directly as an IP.
        host_ips = {hostname}

    # Whole-token match: '10.0.0.1' must not match '10.0.0.12', and
    # 'ceph1' must not match 'ceph10' or 'myceph1'.
    hostname_re = re.compile(
        r'(?<![\w.-])' + re.escape(hostname) + r'(?![\w-]|\.\w)')

    osds = []

    for line in osd_dump.splitlines():
        # Expected layout: 'osd.<id> <up|down> <in|out> weight ...'.
        osd_match = re.match(r'(osd\.\d+)\s+(\S+)\s+(\S+)', line)
        if not osd_match:
            continue

        osd_ips = extract_ips_from_osd_line(line)
        if host_ips.isdisjoint(osd_ips) and not hostname_re.search(line):
            continue

        # Compare the state columns exactly. A substring test for 'up'/'in'
        # also hits 'up_from', 'up_thru' and 'last_clean_interval', which
        # are present on lines of OSDs that are down or out.
        if osd_match.group(2) == 'up' and osd_match.group(3) == 'in':
            osds.append(osd_match.group(1))

    return osds


def _run_command(cmd):
    """
    Run a command (no shell) and capture its output.

    :param cmd: argument list passed to subprocess.Popen.
    :returns: (returncode, stdout bytes, stderr bytes).
    """
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, err = proc.communicate()
    return proc.returncode, output, err


def _error_text(err, cephadm):
    """
    Decode a failed command's stderr for display.

    When running through cephadm, the first 4 stderr lines are dropped.
    These are presumably cephadm's own status lines; confirm this against
    the cephadm version in use. Undecodable bytes are replaced rather than
    raising.
    """
    err_lines = err.decode('utf-8', errors='replace').split('\n')
    return '\n'.join(err_lines[(4 if cephadm else 0):])


def main():
    """
    Entry point of the 'ceph osd' bluefs DB usage Nagios plugin.

    Runs 'ceph osd dump' and selects the OSDs that find_osds_for_host()
    attributes to --host. For each of those it runs
    'ceph daemon <osd> perf dump' and reports
    bluefs db_used_bytes / db_total_bytes as a percentage.

    :returns: STATUS_OK, STATUS_CRITICAL (an OSD is at or above the
        --critical percentage, or a ceph command failed) or STATUS_UNKNOWN
        (invalid arguments or missing files).
    """
    parser = argparse.ArgumentParser(description="'ceph osd' nagios plugin.")

    parser.add_argument('-e', '--exe', help='ceph executable [%s]' % CEPH_COMMAND)
    parser.add_argument('-A', '--admexe', help='cephadm executable [%s]' % CEPH_ADM_COMMAND)
    parser.add_argument('-c', '--conf', help='alternative ceph conf file')
    parser.add_argument('-m', '--monaddress', help='ceph monitor address[:port]')
    parser.add_argument('-i', '--id', help='ceph client id')
    parser.add_argument('-k', '--keyring', help='ceph client keyring file')
    parser.add_argument('-H', '--host', help='osd host', required=True)
    parser.add_argument('-C', '--critical', help='critical threshold', default=60)
    parser.add_argument('-V', '--version', help='show version and exit', action='store_true')
    parser.add_argument('-a', '--cephadm', help='uses cephadm to execute the command', action='store_true')

    args = parser.parse_args()

    if args.version:
        print('version %s' % __version__)
        return STATUS_OK

    # Validate the threshold up front instead of crashing with a traceback
    # on the first OSD.
    try:
        critical = float(args.critical)
    except ValueError:
        print("UNKNOWN: invalid critical threshold '%s'" % args.critical)
        return STATUS_UNKNOWN

    cephadm_exec = args.admexe if args.admexe else CEPH_ADM_COMMAND
    ceph_exec = args.exe if args.exe else CEPH_COMMAND
    if args.cephadm:
        if not os.path.exists(cephadm_exec):
            print("UNKNOWN: cephadm executable '%s' doesn't exist" % cephadm_exec)
            return STATUS_UNKNOWN
    elif not os.path.exists(ceph_exec):
        print("UNKNOWN: ceph executable '%s' doesn't exist" % ceph_exec)
        return STATUS_UNKNOWN

    if args.conf and not os.path.exists(args.conf):
        print("UNKNOWN: ceph conf file '%s' doesn't exist" % args.conf)
        return STATUS_UNKNOWN

    if args.keyring and not os.path.exists(args.keyring):
        print("UNKNOWN: keyring file '%s' doesn't exist" % args.keyring)
        return STATUS_UNKNOWN

    if not args.host:
        print("UNKNOWN: no OSD hostname given")
        return STATUS_UNKNOWN

    base_ceph_cmd = [ceph_exec]

    if args.cephadm:
        # Run ceph through 'cephadm shell'. The keyring is mounted read-only
        # at the same path (-v src:dst:ro) so that --keyring below resolves.
        base_ceph_cmd = [cephadm_exec, 'shell']
        if args.keyring:
            base_ceph_cmd += ['-v', '%s:%s:ro' % (args.keyring, args.keyring)]
        base_ceph_cmd += ['--', ceph_exec]

    if args.monaddress:
        base_ceph_cmd += ['-m', args.monaddress]
    if args.conf:
        base_ceph_cmd += ['-c', args.conf]
    if args.id:
        base_ceph_cmd += ['--id', args.id]
    if args.keyring:
        base_ceph_cmd += ['--keyring', args.keyring]

    returncode, output, err = _run_command(base_ceph_cmd + ['osd', 'dump'])
    if returncode != 0 or not output:
        print("CRITICAL: %s" % _error_text(err, args.cephadm))
        return STATUS_CRITICAL

    osds_up = find_osds_for_host(output.decode('utf-8', errors='replace'),
                                 args.host, args.conf)

    final_status = STATUS_OK
    lines = []

    for osd in osds_up:
        daemon_ceph_cmd = base_ceph_cmd + [
            '--format', 'json', 'daemon', osd, 'perf', 'dump']

        returncode, output, err = _run_command(daemon_ceph_cmd)
        if returncode != 0 or not output:
            print("CRITICAL: %s" % _error_text(err, args.cephadm))
            return STATUS_CRITICAL

        try:
            data = json.loads(output.decode('utf-8'))
        except ValueError:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            print("CRITICAL: failed to load json")
            return STATUS_CRITICAL

        bluefs = data.get('bluefs') if isinstance(data, dict) else None
        if not bluefs:
            continue

        db_total_bytes = bluefs.get('db_total_bytes')
        db_used_bytes = bluefs.get('db_used_bytes')
        # Skip OSDs that do not report usable sizes rather than crashing
        # with TypeError (missing value) or ZeroDivisionError (zero total).
        if not db_total_bytes or db_used_bytes is None:
            continue
        perc = float(db_used_bytes) / float(db_total_bytes) * 100

        if perc >= critical:
            final_status = STATUS_CRITICAL

        lines.append("%s=%.2f%%" % (osd, perc))

    prefix = "OK" if final_status == STATUS_OK else "CRITICAL"
    print("%s: %s" % (prefix, ' '.join(lines)))

    return final_status


if __name__ == "__main__":
    sys.exit(main())
