# Copyright 2022-2025 Ping Identity Corporation. All Rights Reserved
#
# This code is to be used exclusively in connection with Ping Identity
# Corporation software or services. Ping Identity Corporation only offers
# such software or services to legal entities who have entered into a
# binding license agreement with Ping Identity Corporation.

# -*- coding: utf-8 -*-

# Python imports
import time

# Framework imports - load tasks used by configuration file
from shared.lib.utils.login_session import LoginSession
from shared.lib.platform_utils import PlatformUtils
from shared.lib.utils.exception import FailException
from pyrock.lib.PyRockRun import get_pyrock_run
from shared.lib.utils.duration import format_time
from pyrock.lib.scheduler.tasks.StepTask import StepTask
from pyrock.lib.report.json.tasks.ReportSimulation import ReportSimulation
from pyrock.tasks.deployment.configuration_idm import PrepareWorkloadTask
from pyrock.tasks.scenario.gatling import GatlingTask
from pyrock.tasks.idm.general import DumpIDMIDWithAPITask
from pyrock.tasks.deployment.installation import DeployOverseerTask


# Shared PyRock run context, fetched once when this module is imported; the tasks below use it
# for logging and for looking up components, pods and the default deployment.
pyrock_run = get_pyrock_run()

# TODO (general): every task meant to act as a scenario should pass its result to the report.


class PlatformGroupStepTask(StepTask):
    """
    Base class for step tasks handling IDM managed (platform) groups.

    Subclasses fill ``num_groups``, ``group_name_base``, ``ok_items`` and
    ``elapsed_time`` during their steps; ``post`` turns those values into a
    ``ReportSimulation`` unless the ``no-report`` option is set.
    """

    num_groups_keyword = "num-groups"  # name of the task option holding the number of groups
    platform_utils = None  # PlatformUtils helper, created in pre()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_groups = None  # number of groups handled by the task (reported as "total")
        self.group_name_base = None  # handled groups have names starting with this string
        self.condition = None  # dynamic-group condition; None means static groups
        self.headers = None  # HTTP headers passed to every IDM managed-group call

        # for report
        self.elapsed_time = 0  # seconds spent in the timed step
        self.ok_items = 0  # how many items were processed successfully
        self.report_simulation = None

    def pre(self):
        """
        Validate the target and prepare the helpers shared by all steps.

        :raises FailException: if the target is neither an IDM component nor an IDM pod
        """
        if pyrock_run.is_component(self.target_name) and self.target.component_type == "idm":
            pyrock_run.log(f"target ({self.target_name}) is an IDM component")
        elif pyrock_run.is_pod(self.target_name) and self.target.component.component_type == "idm":
            # for a pod target, the component owning the pod must be IDM
            pyrock_run.log(f"target ({self.target_name}) is an IDM pod")
        else:
            raise FailException(f"target ({self.target_name}) must be an IDM component or pod")

        self.platform_utils = PlatformUtils(
            pyrock_run.get_components(), deployment=pyrock_run.deployment.default_deployment
        )

        # force-renewed OAuth2 user headers obtained from AM, reused for all IDM group calls
        self.headers = pyrock_run.get_component("am").get_user_oauth2_headers(force_renew=True)
        self.headers["Accept-API-Version"] = "resource=1.0"

    def build_report(self):
        """
        Fill ``self.report_simulation`` from ``num_groups``, ``ok_items`` and ``elapsed_time``.

        Throughput, mean response time and pass percentage are reported as 0 when
        they cannot be computed: no group handled, or no time measured (e.g. the
        timed step was skipped).
        """
        total = self.num_groups or 0  # None (never set) is treated as "nothing handled"

        self.report_simulation = ReportSimulation()
        # NOTE(review): for a pod target this looks the name up as a component — confirm it resolves.
        self.report_simulation.target_hostname = pyrock_run.get_component(self.target_name).hostname
        self.report_simulation.options = {"num_groups": self.num_groups, "base_name": self.group_name_base}
        self.report_simulation.tool_name = "-"
        self.report_simulation.stats = {"global": {}}
        global_stats = self.report_simulation.stats["global"]
        # i.e. processed entries (name kept from Gatling-based results)
        global_stats["numberOfRequests"] = {
            "ok": self.ok_items,
            "ko": total - self.ok_items,  # failed items = total - successful
            "total": total,
        }
        global_stats["duration"] = {"total": self.elapsed_time}

        if total > 0 and self.elapsed_time > 0:
            avg_throughput = total / self.elapsed_time  # groups per second
            avg_response_time = round((1000 / avg_throughput), ndigits=3)  # milliseconds per group
            avg_throughput = round(avg_throughput, ndigits=3)
        else:
            # avoid ZeroDivisionError when nothing was handled or no time was measured
            avg_throughput = 0
            avg_response_time = 0
        global_stats["meanNumberOfRequestsPerSecond"] = {"total": avg_throughput}
        global_stats["meanResponseTime"] = {"total": avg_response_time}

        # "generic results" section
        self.report_simulation.num_requests = total
        self.report_simulation.num_requests_pass = self.ok_items
        self.report_simulation.num_requests_percent_pass = (
            round(self.ok_items * 100.0 / total, ndigits=3) if total > 0 else 0
        )
        self.report_simulation.avg_num_of_requests_per_second = avg_throughput
        self.report_simulation.avg_response_time = avg_response_time

    def post(self):
        """Build the report (unless the ``no-report`` option is set) and mark the task as passed."""
        if not self.is_option("no-report"):
            self.build_report()
        # NOTE(review): this marks the task as passed unconditionally, which may override a
        # failure set by an earlier step (set_result_fail) — confirm against StepTask semantics.
        self.set_result_pass()


class CreatePlatformGroupTask(PlatformGroupStepTask):
    """
    Task for creating both static and dynamic platform groups.

    Group names follow ``[{group-name-prefix}_]{static|dynamic}_group_{i}`` for
    ``i`` in ``0 .. num-groups - 1``. Dynamic groups are created when the
    ``condition`` option is given, static groups otherwise.
    """

    def pre(self):
        """Build the group name base and read the ``condition`` and ``num-groups`` options."""
        super().pre()

        # build the base of group name with pattern:
        # {group_name_prefix}_{static/dynamic_group}_
        # the prefix part is omitted when group-name-prefix is empty;
        # the trailing underscore is always appended and separates the base from the group index
        if self.is_option("condition"):
            self.group_name_base = "dynamic_group"
            self.condition = self.get_option(option_name="condition")
        else:
            self.group_name_base = "static_group"
        group_name_prefix = str(self.get_option(option_name="group-name-prefix", default=""))
        if group_name_prefix:
            self.group_name_base = f"{group_name_prefix}_{self.group_name_base}"

        self.group_name_base = f"{self.group_name_base}_"  # at the very end a number will be appended

        # num-groups param handling
        if self.is_option(self.num_groups_keyword, required=True):
            # int() so a numeric value supplied as text still works with range() and arithmetic below
            self.num_groups = int(self.get_option(option_name=self.num_groups_keyword))

    def step1(self):
        """Search if IDM managed groups already exist"""
        # Only the last expected group ({base}{num_groups - 1}) is looked up; when it is found,
        # the following steps are skipped.
        # TODO: search in IDM with "sw" (starts with) "static_group_" to get any valid group to be found
        #  (to cover cases like not all groups were created in previous run, only few were left after previous run)
        group_name = f"{self.group_name_base}{self.num_groups - 1}"
        print("")
        print(f"Searching IDM managed group {group_name}")
        print("")
        response = self.target.search_managed_group(headers=self.headers, group_name=group_name)
        print(response)
        print("")
        if len(response) > 0:
            print("Groups already recorded, we should skip next steps")
            self.skip_next_steps()
            # TODO - shouldn't we rather delete existing groups in this case? There are possible leftovers..
        else:
            print("Groups not yet recorded, we should add them")

    def step2(self):
        """Add IDM managed groups into IDM (timed; the duration is stored in ``elapsed_time``)."""
        self.start_time = time.time()
        group_kind = "static" if self.condition is None else "dynamic"
        group_description = f"Description for {group_kind} group "
        for i in range(self.num_groups):
            group_name = f"{self.group_name_base}{i}"
            print("-" * 55)
            print()
            # 412 is accepted alongside 201 — presumably "group already exists"; TODO confirm
            self.platform_utils.create_managed_group(
                group_name=group_name,
                group_description=f"{group_description}{group_name}",
                expected_status=[201, 412],
                condition=self.condition,
                headers=self.headers,
            )
        self.elapsed_time = time.time() - self.start_time

    def step3(self):
        """
        List IDM managed groups and count those whose name starts with ``group_name_base``.

        The task result is set to fail when that count differs from ``num_groups``.
        """
        self.ok_items = 0
        # TODO - we should do search_managed_group to make IDM filter groups for us
        #        also filtering and counting the results in a loop will not be necessary
        response = self.platform_utils.list_managed_groups(self.headers)
        groups = response.json()["result"]
        for group in groups:
            group_name = group["name"]
            if group_name.startswith(self.group_name_base):
                self.ok_items += 1
            print(f"- {group_name}")
        if self.num_groups == self.ok_items:
            print(f"PASS : found expected number of groups ({self.num_groups})")
        else:
            print(f"FAIL : found {self.ok_items} number of groups (expected {self.num_groups})")
            self.set_result_fail()


class DeletePlatformGroupTask(PlatformGroupStepTask):
    """
    Task deleting the IDM managed groups whose name starts with the ``group-name-prefix`` option.
    """

    def pre(self):
        """Read the ``group-name-prefix`` option and reset the matched-group counter."""
        super().pre()

        # we don't need to make a pattern for prefix like for groups creation,
        # also "condition" param is not used here, so it is not even possible
        # NOTE(review): when group-name-prefix is not set the base is "", which every name
        # starts with, so step1 would delete ALL listed managed groups — confirm this is intended.
        self.group_name_base = str(self.get_option(option_name="group-name-prefix", default=""))

        self.num_groups = 0  # incremented in step1 for every group matching the prefix

    def step1(self):
        """
        Delete IDM managed groups whose name starts with ``group_name_base``.

        Deletion is best effort: a failing deletion is logged and the loop continues,
        so ``ok_items`` ends up lower than ``num_groups``.
        """
        response = self.platform_utils.list_managed_groups(self.headers)
        groups = response.json()["result"]
        self.start_time = time.time()
        for group in groups:
            group_name = group["name"]
            group_id = group["_id"]
            print("-" * 55)
            if self.group_name_base is not None and group_name.startswith(self.group_name_base):
                self.num_groups += 1
                try:
                    self.platform_utils.delete_managed_group(group_id=group_id, headers=self.headers)
                except Exception as err:
                    # keep going with the remaining groups, but make the failure visible
                    print(f"- [{format_time()}] Failed to delete group name {group_name} / id {group_id}: {err}")
                else:
                    self.ok_items += 1
            else:
                print(
                    f"- [{format_time()}] Skip group name {group_name} / id {group_id} "
                    f"as it does not starts with base of the name {self.group_name_base}"
                )
            print("")

        self.elapsed_time = time.time() - self.start_time


class DsDisplayUsersFilteredByStateTask(StepTask):
    """Print how many DS entries an ldapsearch with the ``search-filter`` option returns."""

    search_filter = None  # LDAP filter taken from the "search-filter" option in pre()

    def pre(self):
        """
        Checking task config

        :raises FailException: if the target is not a DS component or pod,
            or if the ``search-filter`` option is missing
        """
        if pyrock_run.is_component(self.target_name) and self.target.component_type == "ds":
            pyrock_run.log(f"target ({self.target_name}) is a ds component")
        elif pyrock_run.is_pod(self.target_name) and self.target.component.component_type == "ds":
            pyrock_run.log(f"target ({self.target_name}) is a ds pod")
        else:
            raise FailException(f"target ({self.target_name}) must be a DS component or pod")

        search_filter = self.get_option(option_name="search-filter", default=None)
        if search_filter is None:
            # str(None) would otherwise turn a missing option into the literal filter "None"
            raise FailException(f"option search-filter is required for target ({self.target_name})")
        self.search_filter = str(search_filter)

    def step1(self):
        """Display number of users filtered by state attribute"""
        # NOTE(review): this uses self.target.component and pod=self.target, i.e. assumes a pod
        # target, while pre() also accepts a DS component target — confirm that case works.
        result = self.target.component.ldapsearch(
            pod=self.target, base_dn=self.target.component.base_dn, search_filter=self.search_filter, attributes="dn"
        )
        # every non-empty output line is counted as one matching entry
        # NOTE(review): if the output is wrapped LDIF, a long DN spans several lines and is over-counted
        result_list = [line for line in result.stdout.splitlines() if line]
        print(f'number of users having attribute "{self.search_filter}" : {len(result_list)}')
