#!/usr/bin/env python3

# Tool for canonical RISC-V architecture string.
# Copyright (C) 2011-2025 Free Software Foundation, Inc.
# Contributed by Andrew Waterman (andrew@sifive.com).
#
# This file is part of GCC.
#
# GCC is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3, or (at your option)
# any later version.
#
# GCC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GCC; see the file COPYING3.  If not see
# <http://www.gnu.org/licenses/>.

from __future__ import print_function
import sys
import argparse
import collections
import itertools
import re
import os
from functools import reduce

# ISA specification versions accepted by the -misa-spec command line option.
SUPPORTED_ISA_SPEC = ["2.2", "20190608", "20191213"]
# Order in which single-letter extensions are emitted; the second letter of a
# "z" extension is ranked against the same string when sorting.
CANONICAL_ORDER = "imafdqlcbkjtpvnh"
# First letters that start a multi-letter extension name.  The position in
# this list is also used as a sort key for multi-letter extensions.
LONG_EXT_PREFIXES = ['z', 's', 'h', 'x']

def parse_define_riscv_ext(content):
  """Parse every DEFINE_RISCV_EXT (...) invocation found in CONTENT.

  CONTENT is the text of a .def file.  For each macro invocation the text
  between the matching outer parentheses is handed to
  parse_macro_arguments; every non-empty result is collected.

  Returns a list of dicts as produced by parse_macro_arguments.  An
  invocation whose parentheses never balance is ignored.
  """
  extensions = []
  macro_start = re.compile(r'DEFINE_RISCV_EXT\s*\(')

  pos = 0
  while True:
    # Search from POS in place; slicing CONTENT for every macro would copy
    # the remainder of the file each time.
    match = macro_start.search(content, pos)
    if not match:
      break

    # match.end() is just past the opening parenthesis, so we are already
    # one level deep when the scan starts.
    args_start = match.end()
    depth = 1
    close_pos = args_start
    while close_pos < len(content):
      char = content[close_pos]
      if char == '(':
        depth += 1
      elif char == ')':
        depth -= 1
        if depth == 0:
          break
      close_pos += 1

    if depth == 0:
      # CLOSE_POS is the matching closing parenthesis.
      ext_data = parse_macro_arguments(content[args_start:close_pos])
      if ext_data:
        extensions.append(ext_data)

    # Continue after the closing parenthesis (or past the end of CONTENT
    # when the invocation was unterminated, which ends the outer loop).
    pos = close_pos + 1

  return extensions

def parse_macro_arguments(macro_content):
  """Parse the arguments of a DEFINE_RISCV_EXT macro.

  MACRO_CONTENT is the text between the macro's outer parentheses.  C block
  comments are removed, then the text is split on commas that are not
  nested inside parentheses, braces or string literals.

  Returns a dict with the keys 'name' (argument 0) and 'dep_exts' (the
  result of parse_dep_exts on argument 5), or None when fewer than six
  arguments are present.
  """
  # Remove comments /* ... */.  The non-greedy match with DOTALL also copes
  # with comments that contain '*' characters or span several lines; the
  # previous pattern left such comments (and any commas in them) in place.
  cleaned_content = re.sub(r'/\*.*?\*/', '', macro_content, flags=re.DOTALL)

  # Split arguments by comma, but respect nested structures
  args = []
  current_arg = ""
  paren_count = 0
  brace_count = 0
  in_string = False
  escape_next = False

  for char in cleaned_content:
    if escape_next:
      # Character following a backslash is taken literally.
      current_arg += char
      escape_next = False
      continue

    if char == '\\':
      escape_next = True
      current_arg += char
      continue

    if char == '"':
      in_string = not in_string
      current_arg += char
      continue

    if in_string:
      current_arg += char
      continue

    if char == '(':
      paren_count += 1
    elif char == ')':
      paren_count -= 1
    elif char == '{':
      brace_count += 1
    elif char == '}':
      brace_count -= 1
    elif char == ',' and paren_count == 0 and brace_count == 0:
      # Top-level comma: the current argument is complete.  Empty
      # arguments in the middle are kept so positions stay stable.
      args.append(current_arg.strip())
      current_arg = ""
      continue

    current_arg += char

  # Add the last argument
  if current_arg.strip():
    args.append(current_arg.strip())

  # We need at least 6 arguments to get DEP_EXTS (position 5)
  if len(args) < 6:
    return None

  ext_name = args[0].strip()
  dep_exts_arg = args[5].strip()  # DEP_EXTS is at position 5

  # Parse dependency extensions from the DEP_EXTS argument
  deps = parse_dep_exts(dep_exts_arg)

  return {
    'name': ext_name,
    'dep_exts': deps
  }

def parse_dep_exts(dep_exts_str):
  """Parse the DEP_EXTS argument to extract dependency list with conditions.

  DEP_EXTS_STR is the raw macro argument, e.g. '({"a", {"b", <lambda>}})'.
  One layer of surrounding parentheses and one layer of surrounding braces
  are stripped first.

  Returns a list of dicts in which conditional entries come first:
    {'ext': NAME, 'type': 'conditional', 'condition': CODE} for every
      {"NAME", [...] (...) -> bool ...} entry, and
    {'ext': NAME, 'type': 'simple'} for every other quoted string.
  Duplicate (ext, type) pairs are dropped, keeping the first occurrence.
  """
  # Remove outer parentheses if present
  dep_exts_str = dep_exts_str.strip()
  if dep_exts_str.startswith('(') and dep_exts_str.endswith(')'):
    dep_exts_str = dep_exts_str[1:-1].strip()

  # Remove outer braces if present
  if dep_exts_str.startswith('{') and dep_exts_str.endswith('}'):
    dep_exts_str = dep_exts_str[1:-1].strip()

  if not dep_exts_str:
    return []

  deps = []

  # First, find and process conditional dependencies
  conditional_pattern = r'\{\s*"([^"]+)"\s*,\s*(\[.*?\]\s*\([^)]*\)\s*->\s*bool.*?)\}'
  conditional_spans = []
  consumed_upto = 0

  for match in re.finditer(conditional_pattern, dep_exts_str, re.DOTALL):
    if match.start() < consumed_upto:
      # Lies inside the body of an entry that was already handled.
      continue

    deps.append({'ext': match.group(1), 'type': 'conditional',
                 'condition': match.group(2)})

    # The RE stops at the first closing brace, which is usually in the
    # middle of the condition's code.  Find the real end of the entry by
    # matching braces, starting at the entry's own opening brace.  If the
    # braces never balance, fall back to the end of the RE match.
    entry_end = match.end()
    depth = 0
    for idx in range(match.start(), len(dep_exts_str)):
      char = dep_exts_str[idx]
      if char == '{':
        depth += 1
      elif char == '}':
        depth -= 1
        if depth == 0:
          entry_end = idx + 1
          break

    conditional_spans.append((match.start(), entry_end))
    consumed_upto = entry_end

  # Remove conditional dependency blocks from the string, so that string
  # literals inside the condition code are not mistaken for dependencies.
  remaining_str = dep_exts_str
  for start, end in reversed(conditional_spans):  # Reverse order to maintain indices
    remaining_str = remaining_str[:start] + remaining_str[end:]

  # Now handle simple quoted strings in the remaining text
  for match in re.finditer(r'"([^"]+)"', remaining_str):
    deps.append({'ext': match.group(1), 'type': 'simple'})

  # Remove duplicates while preserving order
  seen = set()
  unique_deps = []
  for dep in deps:
    key = (dep['ext'], dep['type'])
    if key not in seen:
      seen.add(key)
      unique_deps.append(dep)

  return unique_deps

def evaluate_conditional_dependency(ext, dep, xlen, current_exts):
  """Decide whether the conditional dependency DEP of EXT applies.

  DEP is a dict with the keys 'ext' (the implied extension) and
  'condition' (the condition's source text, used only for diagnostics).
  XLEN is 32 or 64 and CURRENT_EXTS is the collection of extensions known
  so far.  Only a fixed set of (EXT, implied extension) pairs is
  understood; any other pair is reported on stderr and treated as not
  applicable.

  Returns True when the implied extension should be added.
  """
  implied = dep['ext']
  condition = dep['condition']

  # zcf is implied only on RV32 with F.
  if implied == 'zcf' and ext in ('zca', 'c', 'zce'):
    return xlen == 32 and 'f' in current_exts

  # zcd is implied whenever D is present.
  if implied == 'zcd' and ext in ('zca', 'c'):
    return 'd' in current_exts

  # zca implies c only if the compressed FP pieces matching F/D are present.
  if implied == 'c' and ext == 'zca':
    if xlen == 32:
      if 'd' in current_exts:
        return 'zcf' in current_exts and 'zcd' in current_exts
      if 'f' in current_exts:
        return 'zcf' in current_exts
      return True
    if xlen == 64:
      return 'zcd' in current_exts if 'd' in current_exts else True
    return False

  # Report error for unhandled conditional dependencies
  import sys
  print(f"ERROR: Unhandled conditional dependency: '{implied}' with condition:", file=sys.stderr)
  print(f"  Condition code: {condition[:100]}...", file=sys.stderr)
  print(f"  Current context: xlen={xlen}, exts={sorted(current_exts)}", file=sys.stderr)
  # For now, return False to be safe
  return False

def resolve_dependencies(arch_parts, xlen):
  """Compute the extensions implied, directly or transitively, by ARCH_PARTS.

  ARCH_PARTS is an iterable of extension names and XLEN the register
  width passed on to evaluate_conditional_dependency.  Dependencies are
  looked up in IMPLIED_EXT and resolved in rounds until a round adds
  nothing new.

  Returns the set of implied extensions that are not in ARCH_PARTS.
  """
  explicit_exts = set(arch_parts)
  implied_deps = set()

  while True:
    # Everything known at the start of this round; conditions are
    # evaluated against this snapshot.
    known = explicit_exts | implied_deps
    discovered = set()

    for ext in known:
      for dep in IMPLIED_EXT.get(ext, ()):
        kind = dep['type']
        if kind == 'simple':
          wanted = True
        elif kind == 'conditional':
          wanted = evaluate_conditional_dependency(ext, dep, xlen, known)
        else:
          wanted = False

        if wanted and dep['ext'] not in known:
          discovered.add(dep['ext'])

    # A round without discoveries means the closure is complete.
    if not discovered:
      return implied_deps
    implied_deps |= discovered

def parse_def_file(file_path, script_dir, processed_files=None, collect_all=False):
  """Parse one .def file together with the files it #includes.

  FILE_PATH is the file to read and SCRIPT_DIR the directory against which
  quoted #include names are resolved.  PROCESSED_FILES is the set of paths
  already visited; it is updated in place and prevents a file from being
  read twice.  A visited or non-existent file contributes nothing.

  Returns a dict mapping extension name to its non-empty dependency list.
  With COLLECT_ALL the result is a tuple of that dict and the set of all
  extension names seen.
  """
  if processed_files is None:
    processed_files = set()

  implied_ext = {}
  all_extensions = set() if collect_all else None

  def outcome():
    # Shape of the return value depends on COLLECT_ALL.
    if collect_all:
      return (implied_ext, all_extensions)
    return implied_ext

  # Avoid infinite recursion
  if file_path in processed_files:
    return outcome()
  processed_files.add(file_path)

  if not os.path.exists(file_path):
    return outcome()

  with open(file_path, 'r') as def_file:
    content = def_file.read()

  # Included files are merged first, so definitions in this file override
  # entries of the same name coming from an include.
  for include_file in re.findall(r'#include\s+"([^"]+)"', content):
    nested = parse_def_file(os.path.join(script_dir, include_file),
                            script_dir, processed_files, collect_all)
    if collect_all:
      nested_implied, nested_all = nested
      all_extensions.update(nested_all)
    else:
      nested_implied = nested
    implied_ext.update(nested_implied)

  # Parse DEFINE_RISCV_EXT blocks using position-based parsing
  for ext_data in parse_define_riscv_ext(content):
    name = ext_data['name']
    dep_list = ext_data['dep_exts']
    if collect_all:
      all_extensions.add(name)
    if dep_list:
      implied_ext[name] = dep_list

  return outcome()

def parse_def_files():
  """Load the implied-extension table from riscv-ext.def.

  The file is looked up next to this script, or in the current working
  directory when __file__ is not defined (e.g., interactive mode).
  Returns the mapping produced by parse_def_file.
  """
  try:
    this_file = os.path.abspath(__file__)
  except NameError:
    base_dir = os.getcwd()
  else:
    base_dir = os.path.dirname(this_file)

  return parse_def_file(os.path.join(base_dir, 'riscv-ext.def'), base_dir)

def get_all_extensions():
  """Load both the implied-extension table and the set of all extensions.

  riscv-ext.def is looked up next to this script, or in the current
  working directory when __file__ is not defined (e.g., interactive mode).
  Returns the (implied_ext, all_extensions) tuple that parse_def_file
  yields with collect_all=True.
  """
  try:
    this_file = os.path.abspath(__file__)
  except NameError:
    base_dir = os.getcwd()
  else:
    base_dir = os.path.dirname(this_file)

  root_def = os.path.join(base_dir, 'riscv-ext.def')
  return parse_def_file(root_def, base_dir, collect_all=True)

#
# IMPLIED_EXT[ext] -> list of dependency records for EXT.  Each record is a
# dict with the keys 'ext' and 'type' ('simple' or 'conditional');
# conditional records also carry the condition's source text in 'condition'.
# The table is built once, when this module is loaded, from riscv-ext.def
# and the files it includes.
#
IMPLIED_EXT = parse_def_files()

def load_profiles():
    """Return the set of profile names declared in riscv-profiles.def.

    The file is read from the directory of this script.  Every line that,
    after stripping, starts with RISCV_PROFILE contributes the text of its
    first double-quoted string, e.g. RISCV_PROFILE("rva20u64", "rv64...")
    yields "rva20u64".  Lines without a double quote are ignored.
    """
    def_path = os.path.join(os.path.dirname(__file__), "riscv-profiles.def")
    with open(def_path) as def_file:
        stripped_lines = (raw.strip() for raw in def_file)
        quoted_fields = (entry.split('"') for entry in stripped_lines
                         if entry.startswith("RISCV_PROFILE"))
        # Field 1 is the text after the first quote, i.e. PROFILE_NAME.
        return {fields[1] for fields in quoted_fields if len(fields) >= 2}

# Profile names read from riscv-profiles.def when this module is loaded.
SUPPORTED_PROFILES = load_profiles()

def is_profile_arch(arch):
    """Return True when ARCH is one of the names in SUPPORTED_PROFILES."""
    known_profiles = SUPPORTED_PROFILES
    return arch in known_profiles

def arch_canonicalize(arch, isa_spec):
  """Return the canonical form of the architecture string ARCH.

  ARCH must start with rv32 or rv64 followed by e, i or g.  'g' is
  rewritten to 'i' plus m, a, f and d (and zicsr/zifencei unless ISA_SPEC
  is '2.2').  Implied extensions are added through resolve_dependencies,
  single-letter extensions are emitted in CANONICAL_ORDER and multi-letter
  extensions are de-duplicated, sorted and joined with '_'.

  Raises Exception for an unexpected base or an unsupported extension.
  """
  # TODO: Support extension version.
  spec_is_2p2 = isa_spec == '2.2'

  base = arch[:5]
  if base not in ('rv32e', 'rv32i', 'rv32g', 'rv64e', 'rv64i', 'rv64g'):
    raise Exception("Unexpected arch: `%s`" % arch[:5])

  new_arch = base.replace("g", "i")
  std_exts = []
  extra_long_ext = []
  if base in ('rv32g', 'rv64g'):
    std_exts = ['m', 'a', 'f', 'd']
    if not spec_is_2p2:
      extra_long_ext = ['zicsr', 'zifencei']

  # Position of the first Z, S, H or X, if any: everything from there on
  # is the '_'-separated list of multi-letter extensions.
  prefix_positions = [idx
                      for idx in (arch.find(p) for p in LONG_EXT_PREFIXES)
                      if idx != -1]
  if prefix_positions:
    split_at = min(prefix_positions)
    long_exts = arch[split_at:].split("_")
    std_exts += list(arch[5:split_at])
  else:
    long_exts = []
    std_exts += list(arch[5:])

  long_exts += extra_long_ext

  #
  # Handle implied extensions using new conditional logic.
  #
  # Extract xlen from architecture string
  # TODO: We should support profile here.
  if arch.startswith('rv32'):
    xlen = 32
  elif arch.startswith('rv64'):
    xlen = 64
  else:
    raise Exception("Unsupported prefix `%s`" % arch)

  # Snapshot of everything named so far; implied extensions are compared
  # against this list, not against the growing long_exts.
  current_exts = std_exts + long_exts
  implied_deps = resolve_dependencies(current_exts, xlen)

  # Filter out zicsr for ISA spec 2.2
  if spec_is_2p2:
    implied_deps.discard('zicsr')

  long_exts.extend(dep for dep in implied_deps if dep not in current_exts)

  # Single letter extension might appear in the long_exts list,
  # because we just append extensions list to the arch string.
  std_exts.extend(name for name in long_exts if len(name) == 1)

  def long_ext_key(name):
    is_zxm = name.startswith("zxm")
    if name.startswith("z") and not is_zxm:
      # If "Z" extensions are named, they should be ordered first by CANONICAL.
      if name[1] not in CANONICAL_ORDER:
        raise Exception("Unsupported extension `%s`" % name)
      canonical_rank = CANONICAL_ORDER.index(name[1])
    else:
      canonical_rank = -1
    return (name.startswith("x"), is_zxm,
            LONG_EXT_PREFIXES.index(name[0]), canonical_rank, name[1:])

  # Drop duplicates and single-letter names, then put the multi-letter
  # extensions in their canonical order.
  long_exts = sorted((name for name in set(long_exts) if len(name) != 1),
                     key=long_ext_key)

  # Put extensions in canonical order.
  new_arch += "".join(letter for letter in CANONICAL_ORDER
                      if letter in std_exts)

  # Check every extension is processed.
  for letter in std_exts:
    if letter != '_' and letter not in CANONICAL_ORDER:
      raise Exception("Unsupported extension `%s`" % letter)

  # Concat rest of the multi-char extensions.
  if long_exts:
    new_arch += "_" + "_".join(long_exts)

  return new_arch

def dump_all_extensions():
  """Print every known extension with the extensions it implies.

  The data comes from get_all_extensions.  Extensions are listed in
  sorted order; conditional dependencies are marked with a trailing '*'.
  A short summary with totals follows the list.
  """
  implied_ext, all_extensions = get_all_extensions()

  print("All supported RISC-V extensions:")
  print("=" * 60)

  if not all_extensions:
    print("No extensions found.")
    return

  # Sorted for consistent output.
  for ext_name in sorted(all_extensions):
    if ext_name not in implied_ext:
      print(f"{ext_name:15} -> (no dependencies)")
      continue

    # Mark conditional deps with *
    dep_strs = [dep['ext'] if dep['type'] == 'simple' else f"{dep['ext']}*"
                for dep in implied_ext[ext_name]]
    print(f"{ext_name:15} -> {', '.join(dep_strs)}")

  print(f"\nTotal extensions: {len(all_extensions)}")
  print(f"Extensions with dependencies: {len(implied_ext)}")
  print(f"Extensions without dependencies: {len(all_extensions) - len(implied_ext)}")

def run_unit_tests():
  """Run the built-in self tests and return a process exit status.

  The tests are plain functions that this runner calls directly, so no
  external test framework is required.  Each test is reported as PASSED
  or FAILED on stdout, followed by a summary.

  Returns 0 when every test passed and 1 otherwise.
  """
  # Define test functions
  def test_basic_arch_parsing():
    """Test basic architecture string parsing."""
    result = arch_canonicalize("rv64i", "20191213")
    assert result == "rv64i"

  def test_simple_extensions():
    """Test simple extension handling."""
    result = arch_canonicalize("rv64im", "20191213")
    assert "zmmul" in result

  def test_implied_extensions():
    """Test implied extension resolution."""
    result = arch_canonicalize("rv64imaf", "20191213")
    assert "zicsr" in result

  def test_conditional_dependencies():
    """Test conditional dependency evaluation."""
    # RV32 with F and C should pull in zca and, conditionally, zcf.
    result = arch_canonicalize("rv32ifc", "20191213")
    parts = result.split("_")
    # parts[0] is the base ISA with the single-letter extensions; the
    # multi-letter extensions follow as separate parts.
    single_letters = parts[0][4:]
    assert "c" in single_letters
    assert "f" in single_letters
    assert "zca" in parts
    assert "zcf" in parts

  def test_parse_dep_exts():
    """Test dependency parsing function."""
    # Test simple dependency
    deps = parse_dep_exts('{"ext1", "ext2"}')
    assert len(deps) == 2
    assert deps[0]['ext'] == 'ext1'
    assert deps[0]['type'] == 'simple'

  def test_evaluate_conditional_dependency():
    """Test conditional dependency evaluation."""
    # Test zcf condition for RV32 with F
    dep = {'ext': 'zcf', 'type': 'conditional', 'condition': 'test'}
    assert evaluate_conditional_dependency('zce', dep, 32, {'f'})

    # Test zcf condition for RV64 with F (should be False)
    assert not evaluate_conditional_dependency('zce', dep, 64, {'f'})

  def test_parse_define_riscv_ext():
    """Test DEFINE_RISCV_EXT parsing."""
    content = '''
    DEFINE_RISCV_EXT(
      /* NAME */ test,
      /* UPPERCASE_NAME */ TEST,
      /* FULL_NAME */ "Test extension",
      /* DESC */ "",
      /* URL */ ,
      /* DEP_EXTS */ ({"dep1", "dep2"}),
      /* SUPPORTED_VERSIONS */ ({{1, 0}}),
      /* FLAG_GROUP */ test,
      /* BITMASK_GROUP_ID */ 0,
      /* BITMASK_BIT_POSITION*/ 0,
      /* EXTRA_EXTENSION_FLAGS */ 0)
    '''

    extensions = parse_define_riscv_ext(content)
    assert len(extensions) == 1
    assert extensions[0]['name'] == 'test'
    assert len(extensions[0]['dep_exts']) == 2

  def test_parse_long_condition_block():
    """Test condition block containing several code blocks."""
    result = arch_canonicalize("rv32ec", "20191213")
    assert "rv32ec_zca" in result

  # Collect test functions
  test_functions = [
    test_basic_arch_parsing,
    test_simple_extensions,
    test_implied_extensions,
    test_conditional_dependencies,
    test_parse_dep_exts,
    test_evaluate_conditional_dependency,
    test_parse_define_riscv_ext,
    test_parse_long_condition_block
  ]

  print("Running unit tests...")

  passed = 0
  failed = 0

  for test_func in test_functions:
    print(f"  Running {test_func.__name__}...", end=" ")
    try:
      test_func()
    except Exception as e:
      # AssertionError is an Exception too, so failed assertions and
      # unexpected errors are both counted as failures.
      print(f"FAILED: {e}")
      failed += 1
    else:
      print("PASSED")
      passed += 1

  print(f"\nTest Summary: {passed} passed, {failed} failed")

  if failed == 0:
    print("\nAll tests passed!")
    return 0

  print(f"\n{failed} test(s) failed!")
  return 1

if __name__ == "__main__":
  # Command line entry point: dump the extension table, run the self
  # tests, or canonicalize each architecture string given.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('-misa-spec', type=str, default='20191213',
                          choices=SUPPORTED_ISA_SPEC)
  arg_parser.add_argument(
    '--dump-all', action='store_true',
    help='Dump all extensions and their implied extensions')
  arg_parser.add_argument('--selftest', action='store_true',
                          help='Run unit tests using pytest')
  arg_parser.add_argument('arch_strs', nargs='*',
                          help='Architecture strings to canonicalize')

  options = arg_parser.parse_args()

  if options.dump_all:
    dump_all_extensions()
  elif options.selftest:
    sys.exit(run_unit_tests())
  elif not options.arch_strs:
    # Nothing to do: show usage and signal failure.
    arg_parser.print_help()
    sys.exit(1)
  else:
    for arch in options.arch_strs:
      # Profile names are printed unchanged.
      if is_profile_arch(arch):
        print(arch)
      else:
        print(arch_canonicalize(arch, options.misa_spec))
