# Copyright 2024-2025 Ping Identity Corporation. All Rights Reserved
#
# This code is to be used exclusively in connection with Ping Identity
# Corporation software or services. Ping Identity Corporation only offers
# such software or services to legal entities who have entered into a
# binding license agreement with Ping Identity Corporation.
# -*- coding: utf-8 -*-

# Python imports
import time
import json
import os

# load tasks used by configuration file
from pyrock.tasks.deployment.datainit import MakeUsersTask
from pyrock.tasks.idm.recon import TestIdmConnector, TestRCSConnection, ReconTask, ReconResultTask
from pyrock.tasks.idm.general import GenerateUserGroups, DumpIDMIDWithAPITask
from pyrock.tasks.scenario.gatling import GatlingTask
from pyrock.tasks.scenario.ds_sdk import DSLdapModifyTask, DSModRateTask, DSConfigStandaloneTask
from pyrock.tasks.deployment.validation import ValidationNumUsers
from pyrock.tasks.deployment.installation import DeployOverseerTask
from pyrock.lib.scheduler.tasks.ShellTask import ShellTask, ShellTaskOverseer
from shared.lib.components.am import AM
from pyrock.lib.scheduler.tasks.StepTask import StepTask
from pyrock.lib.PyRockRun import get_pyrock_run
from pyrock.lib.report.json.tasks.ReportSimulation import ReportSimulation
from shared.lib.platform_utils import PlatformUtils
from shared.lib.utils.constants import IS_TENANT
from shared.lib.utils.exception import FailException

# Module-level PyRock run handle; the tasks below use it for logging
# (pyrock_run.log) and for component lookup (get_component / get_components).
pyrock_run = get_pyrock_run()


class UpdateRcsClientSecretTask(StepTask):
    """Step task that resets the RCS client secret through PlatformUtils."""

    def pre(self):
        """
        Validate the task configuration.

        Raises FailException when the task source is not "controller" and
        ValueError when the target is not an AM component.
        """
        runs_on_controller = self.source == "controller"
        if not runs_on_controller:
            raise FailException("Task must be executed on controller")

        target_is_am = isinstance(self.target, AM)
        if not target_is_am:
            raise ValueError(f"Target {self.target} must be an AM component")

    @staticmethod
    def step1():
        """Log the operation and set the RCS client secret on the run's components."""
        pyrock_run.log("Reset RCSClient password")
        platform_utils = PlatformUtils(components=pyrock_run.get_components())
        platform_utils.set_rcs_client_secret()


class ManagedApplicationStepTask(StepTask):
    """
    Common base for the managed-application step tasks in this file.

    pre() resolves the "idm" component and a freshly renewed OAuth2 header;
    build_report() turns the counters collected by a subclass into report data.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.idm = None
        self.authorization_header = None
        # for report
        self.num_members = 0  # Number of members in a role
        self.elapsed_time = 0  # Time it took for the REST call to assign role to application
        self.num_requests = 0  # Number of HTTP (API) requests
        self.ok_items = 0  # How many HTTP (API) requests were processed successfully

    def pre(self):
        """
        Initialise the test
        """
        self.idm = pyrock_run.get_component("idm")
        # renew OAuth2 token to get whole hour before token expiration
        self.authorization_header = pyrock_run.get_component("am").get_user_oauth2_headers(force_renew=True)

    def build_report(self, num_members=None):
        """
        Build the results for reporting

        :param num_members: number of role members used for the throughput computation;
                            when omitted (None), self.num_members is used.

        NOTE: self.report_simulation must already be assigned before calling this method;
        this class does not create it (AssignApplicationToRoleTask.pre() does).
        """
        if num_members is None:
            # The declared default used to reach the division below and raise TypeError
            num_members = self.num_members

        self.report_simulation.target_hostname = pyrock_run.get_component(self.target_name).hostname
        self.report_simulation.tool_name = "-"
        self.report_simulation.stats = {"global": {}}
        global_stats = self.report_simulation.stats["global"]
        # i.e. processed entries (name kept from Gatling-based results)
        global_stats["numberOfRequests"] = {"ok": self.ok_items}
        global_stats["duration"] = {"total": self.elapsed_time}
        # Here we intentionally don't use num_requests because we are interested in getting the
        # average throughput of members per role assigned to an application
        if self.elapsed_time > 0 and num_members > 0:
            avg_throughput = num_members / self.elapsed_time
            avg_response_time = round((1000 / avg_throughput), ndigits=3)
            avg_throughput = round(avg_throughput, ndigits=3)
        else:
            # Nothing measurable (no elapsed time or no members): report zeros
            # instead of failing with ZeroDivisionError
            avg_throughput = 0.0
            avg_response_time = 0.0
        global_stats["meanNumberOfRequestsPerSecond"] = {"total": avg_throughput}
        global_stats["meanResponseTime"] = {"total": avg_response_time}

        # "generic results" section
        # Here we have to use num_requests as num_members does not represent a request
        if self.num_requests:
            percent_pass = round(self.ok_items * 100.0 / self.num_requests, ndigits=3)
        else:
            # No request was sent, so there is no meaningful pass ratio
            percent_pass = 0.0
        self.report_simulation.num_requests = self.num_requests
        self.report_simulation.num_requests_pass = self.ok_items
        self.report_simulation.num_requests_percent_pass = percent_pass
        self.report_simulation.avg_num_of_requests_per_second = avg_throughput
        self.report_simulation.avg_response_time = avg_response_time


class CreateRCSConnectorServerTask(ManagedApplicationStepTask):
    """Step task that asks IDM to create the RCS connector server."""

    def step1(self):
        """Create the RCS connector server in IDM; mark the task failed on a falsy result"""
        created = self.idm.create_connector_server(headers=self.authorization_header)
        if created:
            return
        self.set_result_fail()


class CreateApplicationIgaTask(ManagedApplicationStepTask):
    """Step task that creates IGA applications in IDM and records their ids."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_applications = None

    def pre(self):
        """Run the parent initialisation, then read the required "num_applications" option."""
        super().pre()
        option_name = "num_applications"
        self.is_option(option_name, required=True)
        self.num_applications = self.get_option(option_name)
        pyrock_run.log(
            f'Number of Applications to be created: {self.num_applications} (+ "AuthoritativeSource" '
            'and a special "Application0" apps)'
        )

    def step1(self):
        """Create IGA Application"""
        created_ids = self.idm.create_application_iga(
            headers=self.authorization_header, max_applications=self.num_applications
        )

        # Persist the application ids (one per line) in the task directory for future use
        ids_path = os.path.join(self.get_task_dir(), "app_ids.txt")
        with open(ids_path, "w") as ids_file:
            ids_file.writelines(f"{app_id}\n" for app_id in created_ids)

        # Two extra applications are expected on top of the requested amount
        actual_total = len(created_ids)
        expected_total = self.num_applications + 2
        if actual_total == expected_total:
            return

        pyrock_run.log(
            f'Number of created apps is not correct - expected: {self.num_applications} (+ "AuthoritativeSource" '
            f'and a special "Application0" apps), actual: {actual_total}',
            level="ERROR",
        )
        self.set_result_fail()


class DeleteApplicationTask(ManagedApplicationStepTask):
    """
    Delete All Managed Applications.  Keep in mind that in case of a tenant only the application
    definition is deleted, not the actual mapping and provisioner files.
    """

    def step1(self):
        """
        Delete All Managed Applications (no-op unless IS_TENANT is set)
        """
        # Clean up only in IDC
        if not IS_TENANT:
            return

        deleted = self.idm.delete_all_managed_applications(headers=self.authorization_header)
        if not deleted:
            self.set_result_fail()


class AssignApplicationToRoleTask(ManagedApplicationStepTask):
    """
    Assign Application(s) to Role(s)

    The role<->application mapping is read from the JSON file given by the required
    "assignment_file" option. Each entry is expected to provide a "name" (role name)
    and an "assignment" (application name) key.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.json_file = None

    def pre(self):
        """Run the parent initialisation, create the report object and read the assignment file option."""
        super().pre()  # Call the parent pre() also
        self.report_simulation = ReportSimulation()
        self.is_option("assignment_file", required=True)
        self.json_file = self.get_option("assignment_file")
        pyrock_run.log(f"Assignment file: {self.json_file}")

    def step1(self):
        """Read role<->application mapping from json file and then discover the _id of the application and
        then pass it to method to assign the role to the application.

        The task is marked as failed when the assignment file is missing or is not valid JSON,
        or when any assignment call returns a falsy result. Applications that cannot be found
        are skipped.
        """
        try:
            # The context manager closes the file; no explicit close() is needed
            with open(self.json_file, "r") as file:
                json_data = json.load(file)
        except FileNotFoundError:
            # Without the mapping there is nothing to process: fail the task explicitly
            # rather than falling through to an unbound json_data
            pyrock_run.log(f"The file {self.json_file} was not found.", level="ERROR")
            self.set_result_fail()
            return
        except json.JSONDecodeError:
            pyrock_run.log(f"Error decoding JSON data in {self.json_file}.", level="ERROR")
            self.set_result_fail()
            return

        for item in json_data:
            role_name = item["name"]
            app_name = item["assignment"]
            app_id = self.idm.get_id_by_name(name=app_name, headers=self.authorization_header)
            if not app_id:
                pyrock_run.log(f"Application {app_name} not found. Skipping...")
                continue
            # NOTE(review): start_time is kept as an instance attribute as in the original
            # code - confirm whether anything outside this method relies on it
            self.start_time = time.time()
            self.num_requests += 1
            result = self.idm.assign_application_role(
                application_id=app_id, role_name=role_name, headers=self.authorization_header
            )
            # Get the cumulative elapsed time and member count because
            # "avg_throughput" is being calculated for all applications
            self.elapsed_time += time.time() - self.start_time
            role_member_count = self.idm.get_role_member_count(
                role_name=role_name, headers=self.authorization_header
            )
            self.num_members += role_member_count
            # Log the count of this role only; self.num_members is the running total
            pyrock_run.log(f"Number of members in role {role_name} is {role_member_count}")
            if result:
                self.ok_items += 1
            else:
                self.set_result_fail()

    def step2(self):
        """build report"""
        if self.num_members > 0:
            self.build_report(self.num_members)
        else:
            pyrock_run.log(f"Skipping Result Generation because number of role members is {self.num_members}")
