#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
#  Copyright (c) 2013 Catalyst IT http://www.catalyst.net.nz
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
# 1.5.2 (2019-06-16) Martin Seener: fixed regex to work with Ceph Nautilus (14.2.x)

import argparse
import os
import re
import subprocess
import sys
import socket

__version__ = '1.6.0'

# Default executable locations; main() uses them unless --admexe / --exe
# override them.
CEPH_ADM_COMMAND, CEPH_COMMAND = '/usr/sbin/cephadm', '/usr/bin/ceph'

# Nagios plugin exit codes, in order: OK, WARNING, CRITICAL, UNKNOWN.
STATUS_OK, STATUS_WARNING, STATUS_ERROR, STATUS_UNKNOWN = 0, 1, 2, 3


def get_host_ips(hostname, conf_file=None):
    """
    Get all IP addresses for a hostname from DNS and ceph config.

    Two best-effort sources are combined:

    * every address ``socket.getaddrinfo`` returns for *hostname*;
    * the IPv4 addresses listed in the ``public_addr`` (or ``public_addrv``)
      option of the ``[mon.<hostname>]`` section of *conf_file*. Ceph treats
      spaces, dashes and underscores in option names as equivalent, so
      ``public addr`` and ``public-addr`` are recognised too.

    Resolution and file-read failures are ignored on purpose, so the caller
    can fall back to other ways of matching the host.

    :param hostname: host name (or IP literal) to look up.
    :param conf_file: optional path to a ceph configuration file.
    :returns: a set of IP address strings (possibly empty).
    """
    ips = set()

    # Get IPs from DNS resolution (best effort: an unresolvable name is fine).
    try:
        addrinfo = socket.getaddrinfo(hostname, None, 0, socket.SOCK_STREAM)
    except (OSError, UnicodeError):
        # socket.gaierror is an OSError subclass; UnicodeError is raised for
        # names the IDNA codec rejects (e.g. empty or over-long labels).
        addrinfo = []
    for info in addrinfo:
        ips.add(info[-1][0])

    if not conf_file:
        return ips

    # Get IPs from the ceph config [mon.<hostname>] public_addr option.
    section_header = '[mon.' + hostname + ']'
    ip_pattern = re.compile(r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b')
    try:
        with open(conf_file, 'r') as f:
            in_mon_section = False
            for line in f:
                stripped = line.strip()
                # Both '#' and ';' introduce comment lines in ceph.conf.
                if not stripped or stripped.startswith(('#', ';')):
                    continue
                if stripped == section_header:
                    in_mon_section = True
                    continue
                if not in_mon_section:
                    continue
                # Any other section header ends the mon section.
                if stripped.startswith('['):
                    break
                key, sep, value = stripped.partition('=')
                # Normalise 'public addr' / 'public-addr' to 'public_addr'.
                key = re.sub(r'[ _-]+', '_', key.strip())
                if not sep or key not in ('public_addr', 'public_addrv'):
                    continue
                for candidate in ip_pattern.findall(value):
                    if all(0 <= int(p) <= 255 for p in candidate.split('.')):
                        ips.add(candidate)
                break
    except (OSError, UnicodeError):
        # Missing/unreadable file or undecodable content: keep what we have.
        pass

    return ips


def extract_ips_from_osd_line(line):
    """
    Return the set of dotted-quad IPv4 addresses found in *line*.

    Candidates with an octet above 255 are dropped. Because the result is a
    set, an address that appears several times on the line is returned once.
    """
    def _is_valid(candidate):
        # Every octet of a real IPv4 address lies in 0..255.
        return all(0 <= int(octet) <= 255 for octet in candidate.split('.'))

    candidates = re.findall(r'\b(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\b', line)
    return {candidate for candidate in candidates if _is_valid(candidate)}


def match_osds_for_host(osd_dump, hostname, osd_id_pattern, conf_file=None):
    """
    Find all OSDs in the dump that belong to the given hostname.

    An ``osd dump`` line is attributed to the host when one of the IPv4
    addresses on it is one of the host's addresses (``get_host_ips``), or
    when the hostname appears on it as a whole token.

    The state is read from the two columns that follow the OSD id
    (``osd.N up|down in|out``). An OSD that is ``up`` counts as up whether it
    is ``in`` or ``out``. A ``down`` OSD goes to ``down`` and also to
    ``down_in`` or ``down_out``.

    :param osd_dump: text output of ``ceph osd dump``.
    :param hostname: host whose OSDs are wanted.
    :param osd_id_pattern: regular expression that must match the whole
        numeric OSD id ('3' selects osd.3 but not osd.13). A pattern written
        against the full 'osd.N' id is also accepted.
    :param conf_file: optional ceph config file forwarded to get_host_ips.
    :returns: dict with 'up', 'down', 'down_in', 'down_out' lists of OSD ids
        ('osd.N') in dump order.
    """
    host_ips = get_host_ips(hostname, conf_file)

    if not host_ips:
        # If we couldn't resolve the hostname, fall back to treating the
        # given value itself as an IP address.
        host_ips = {hostname}

    # Whole-token hostname match: 'ceph1' must not match 'ceph10', and an IP
    # given as hostname ('10.0.0.1') must not match '10.0.0.10'.
    host_re = re.compile(r'(?<![\w.-])' + re.escape(hostname) + r'(?![\w-])')
    # Parse the state columns explicitly: plain substring tests for 'up' /
    # 'in' are useless because dump lines typically contain 'up_from' and
    # 'last_clean_interval' whatever the OSD state is.
    state_re = re.compile(r'^(osd\.(\d+))\s+(up|down)\s+(in|out)\b')

    up = []
    down = []
    down_in = []
    down_out = []

    for line in osd_dump.splitlines():
        state_match = state_re.match(line)
        if not state_match:
            continue
        osd_id, osd_num, state, membership = state_match.groups()

        # Check if the OSD ID matches the pattern (whole-number match).
        if not (re.fullmatch(osd_id_pattern, osd_num)
                or re.fullmatch(osd_id_pattern, osd_id)):
            continue

        # Keep the OSD if one of its addresses belongs to the host, or if
        # the line names the host (for hostnames in newer ceph versions).
        if (host_ips.isdisjoint(extract_ips_from_osd_line(line))
                and not host_re.search(line)):
            continue

        if state == 'up':
            up.append(osd_id)
            continue
        if membership == 'in':
            down_in.append(osd_id)
        else:
            down_out.append(osd_id)
        down.append(osd_id)

    return {
        'up': up,
        'down': down,
        'down_in': down_in,
        'down_out': down_out
    }


def _build_ceph_command(args, ceph_exec, cephadm_exec):
    """
    Build the argv list that runs ``ceph ... osd dump`` as requested by *args*.

    With ``--cephadm`` the ceph binary is run via ``cephadm shell -- <ceph>``.
    A ``--keyring`` file is then also passed to cephadm as a read-only volume
    (``-v path:path:ro``), so the ``--keyring`` option given to ceph refers
    to the same path.
    """
    ceph_cmd = []
    if args.cephadm:
        ceph_cmd += [cephadm_exec, 'shell']
        if args.keyring:
            ceph_cmd += ['-v', '%s:%s:ro' % (args.keyring, args.keyring)]
        # NOTE(review): --conf is forwarded to ceph but, unlike the keyring,
        # not mounted into the cephadm shell; confirm it is reachable there.
        ceph_cmd.append('--')
    ceph_cmd.append(ceph_exec)

    # Optional connection settings, forwarded verbatim to the ceph CLI.
    for flag, value in (('-m', args.monaddress), ('-c', args.conf),
                        ('--id', args.id), ('--keyring', args.keyring)):
        if value:
            ceph_cmd += [flag, value]

    ceph_cmd += ['osd', 'dump']
    return ceph_cmd


def _print_osd_details(up, down_in, down_out, crit):
    """Print the OSD lists by state followed by the Nagios perfdata line."""
    print("Up OSDs: " + " ".join(up))
    print("Down+In OSDs: " + " ".join(down_in))
    print("Down+Out OSDs: " + " ".join(down_out))
    print("| 'osd_up'=%d 'osd_down_in'=%d;;%d 'osd_down_out'=%d;;%d" % (len(up), len(down_in), crit, len(down_out), crit))


def main():
    """
    Run the 'ceph osd' check and return a Nagios status code.

    Parses the command line, runs ``ceph osd dump`` (optionally through
    ``cephadm shell``), selects the OSDs of ``--host`` and returns:

    * STATUS_ERROR (CRITICAL) when at least ``--crit`` matching OSDs are
      down, or when the ceph command fails / prints nothing;
    * STATUS_WARNING when fewer than ``--crit`` are down, or when no matching
      OSD was found;
    * STATUS_OK when matching OSDs are up and none is down (also for -V);
    * STATUS_UNKNOWN for missing files/executables or a launch failure.

    Down+out OSDs are always listed but only count as down with ``--out``.
    """
    # parse args
    parser = argparse.ArgumentParser(description="'ceph osd' nagios plugin.")
    parser.add_argument('-e', '--exe', help='ceph executable [%s]' % CEPH_COMMAND)
    parser.add_argument('-A', '--admexe', help='cephadm executable [%s]' % CEPH_ADM_COMMAND)
    parser.add_argument('-c', '--conf', help='alternative ceph conf file')
    parser.add_argument('-m', '--monaddress', help='ceph monitor address[:port]')
    parser.add_argument('-i', '--id', help='ceph client id')
    parser.add_argument('-k', '--keyring', help='ceph client keyring file')
    parser.add_argument('-V', '--version', help='show version and exit', action='store_true')
    parser.add_argument('-H', '--host', help='osd host', required=True)
    parser.add_argument('-I', '--osdid', help='osd id', required=False)
    parser.add_argument('-C', '--crit', help='Number of failed OSDs to trigger critical (default=2)', type=int, default=2, required=False)
    parser.add_argument('-o', '--out', help='check osds that are set OUT', default=False, action='store_true', required=False)
    parser.add_argument('-a', '--cephadm', help='uses cephadm to execute the command', action='store_true')
    args = parser.parse_args()

    # Answer --version before any environment checks so it also works on
    # hosts without ceph/cephadm installed.
    if args.version:
        print('version %s' % __version__)
        return STATUS_OK

    # validate args
    cephadm_exec = args.admexe if args.admexe else CEPH_ADM_COMMAND
    ceph_exec = args.exe if args.exe else CEPH_COMMAND
    # Only the locally launched executable is checked; under cephadm the
    # ceph path is presumably resolved inside the cephadm shell.
    if args.cephadm:
        if not os.path.exists(cephadm_exec):
            print("ERROR: cephadm executable '%s' doesn't exist" % cephadm_exec)
            return STATUS_UNKNOWN
    elif not os.path.exists(ceph_exec):
        print("ERROR: ceph executable '%s' doesn't exist" % ceph_exec)
        return STATUS_UNKNOWN

    if args.conf and not os.path.exists(args.conf):
        print("OSD ERROR: ceph conf file '%s' doesn't exist" % args.conf)
        return STATUS_UNKNOWN

    if args.keyring and not os.path.exists(args.keyring):
        print("OSD ERROR: keyring file '%s' doesn't exist" % args.keyring)
        return STATUS_UNKNOWN

    if not args.osdid:
        # No id given: accept every OSD.
        args.osdid = '[^ ]*'

    if not args.host:
        print("OSD ERROR: no OSD hostname given")
        return STATUS_UNKNOWN

    # exec command
    ceph_cmd = _build_ceph_command(args, ceph_exec, cephadm_exec)
    try:
        proc = subprocess.run(ceph_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as e:
        # e.g. the file exists but is not executable.
        print("OSD ERROR: %s" % e)
        return STATUS_UNKNOWN
    output = proc.stdout.decode('utf-8', errors='replace')

    if proc.returncode != 0 or not output:
        print("OSD ERROR: %s" % proc.stderr.decode('utf-8', errors='replace'))
        return STATUS_ERROR

    result = match_osds_for_host(output, args.host, args.osdid, args.conf)
    up = result['up']
    down_in = result['down_in']
    down_out = result['down_out']
    # Down+out OSDs only count as failures when --out is requested.
    down = result['down'] if args.out else down_in

    if down:
        critical = len(down) >= args.crit
        print("OSD %s: Down OSD%s on %s: %s" % ('CRITICAL' if critical else 'WARNING', 's' if len(down) > 1 else '', args.host, " ".join(down)))
        _print_osd_details(up, down_in, down_out, args.crit)
        return STATUS_ERROR if critical else STATUS_WARNING

    if up:
        print("OSD OK")
        _print_osd_details(up, down_in, down_out, args.crit)
        return STATUS_OK

    print("OSD WARN: no OSD.%s found on host %s" % (args.osdid, args.host))
    return STATUS_WARNING

if __name__ == "__main__":
    # Nagios reads the check result from the process exit status, so the
    # status code returned by main() becomes the exit code.
    raise SystemExit(main())
