# stripsCSPPlanner.py - CSP planner where actions are represented using STRIPS
# AIFCA Python3 code Version 0.9.4 Documentation at http://aipython.org
# Download the zip file and read aipython.pdf for documentation

# Artificial Intelligence: Foundations of Computational Agents http://artint.info
# Copyright David L Poole and Alan K Mackworth 2017-2022.
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# See: http://creativecommons.org/licenses/by-nc-sa/4.0/deed.en

from cspProblem import Variable, CSP, Constraint

class CSP_from_STRIPS(CSP):
    """A CSP where:
    * CSP variables are constructed for each feature and time, and each action and time
    * the dynamics are specified by the STRIPS representation of actions
    """

    def __init__(self,  planning_problem, number_stages=2):
        prob_domain = planning_problem.prob_domain
        initial_state = planning_problem.initial_state
        goal = planning_problem.goal
        # self.action_vars[t] is the action variable for time t
        self.action_vars = [Variable(f"Action{t}", prob_domain.actions)
                                for t in range(number_stages)]
        # feat_time_var[f][t] is the variable for feature f at time t
        feat_time_var = {feat: [Variable(f"{feat}_{t}",dom)
                                         for t in range(number_stages+1)]
                           for (feat,dom) in prob_domain.feature_domain_dict.items()}
                         
        # initial state constraints:
        constraints = [Constraint((feat_time_var[feat][0],), is_(val))
                            for (feat,val) in initial_state.items()]
                            
        # goal constraints on the final state:
        constraints += [Constraint((feat_time_var[feat][number_stages],),
                                        is_(val))
                            for (feat,val) in goal.items()]
                            
        # precondition constraints:
        constraints += [Constraint((feat_time_var[feat][t], self.action_vars[t]),
                                   if_(val,act))  # feat@t==val if action@t==act
                            for act in prob_domain.actions
                            for (feat,val) in act.preconds.items()
                            for t in range(number_stages)]
                            
        # effect constraints:
        constraints += [Constraint((feat_time_var[feat][t+1], self.action_vars[t]),
                                   if_(val,act))  # feat@t+1==val if action@t==act
                            for act in prob_domain.actions
                            for feat,val in act.effects.items()
                            for t in range(number_stages)]
        # frame constraints:
        
        constraints += [Constraint((feat_time_var[feat][t], self.action_vars[t], feat_time_var[feat][t+1]),
                                   eq_if_not_in_({act for act in prob_domain.actions
                                                  if feat in act.effects}))
                            for feat in prob_domain.feature_domain_dict
                            for t in range(number_stages) ]
        variables = set(self.action_vars) | {feat_time_var[feat][t]
                                            for feat in prob_domain.feature_domain_dict
                                            for t in range(number_stages+1)}
        CSP.__init__(self, variables, constraints)

    def extract_plan(self,soln):
        return [soln[a] for a in self.action_vars]

def is_(val):
    """returns a function that is true when it is it applied to val.
    """
    #return lambda x: x == val
    def is_fun(x):
        return x == val
    is_fun.__name__ = "value_is_"+str(val)
    return is_fun

def if_(v1,v2):
    """if the second argument is v2, the first argument must be v1"""
    #return lambda x1,x2: x1==v1 if x2==v2 else True
    def if_fun(x1,x2): 
        return x1==v1 if x2==v2 else True
    if_fun.__name__ = "if x2 is "+str(v2)+" then x1 is "+str(v1)
    return if_fun

def eq_if_not_in_(actset):
    """first and third arguments are equal if action is not in actset"""
    # return lambda x1, a, x2: x1==x2 if a not in actset else True
    def eq_if_not_fun(x1, a, x2):
        return x1==x2 if a not in actset else True
    eq_if_not_fun.__name__ = "first and third arguments are equal if action is not in "+str(actset)
    return eq_if_not_fun

def con_plan(prob,horizon):
    """finds a plan for problem prob given horizon.
    """
    csp = CSP_from_STRIPS(prob, horizon)
    sol = Con_solver(csp).solve_one()
    return csp.extract_plan(sol) if sol else sol
    
from searchGeneric import Searcher
from stripsProblem import delivery_domain
from cspConsistency import Search_with_AC_from_CSP, Con_solver
from stripsProblem import Planning_problem, problem0, problem1, problem2, blocks1, blocks2, blocks3

# Problem 0
# con_plan(problem0,1) # should it succeed?
# con_plan(problem0,2) # should it succeed?
# con_plan(problem0,3) # should it succeed?
# To use search to enumerate solutions
#searcher0a = Searcher(Search_with_AC_from_CSP(CSP_from_STRIPS(problem0, 1)))
#print(searcher0a.search())  # returns path to solution

## Problem 1
# con_plan(problem1,5) # should it succeed?
# con_plan(problem1,4) # should it succeed?
## To use search to enumerate solutions:
#searcher15a = Searcher(Search_with_AC_from_CSP(CSP_from_STRIPS(problem1, 5)))
#print(searcher15a.search())  # returns path to solution

## Problem 2
#con_plan(problem2, 6)  # should fail??
#con_plan(problem2, 7)  # should succeed???

## Example 6.13
problem3 = Planning_problem(delivery_domain, 
                            {'SWC':True, 'RHC':False}, {'SWC':False})
#con_plan(problem3,2)  # Horizon of 2
#con_plan(problem3,3)  # Horizon of 3

problem4 = Planning_problem(delivery_domain,{'SWC':True},
                               {'SWC':False, 'MW':False, 'RHM':False})

# For the stochastic local search:
#from cspSLS import SLSearcher, Runtime_distribution
# cspplanning15 = CSP_from_STRIPS(problem1, 5) # should succeed
#se0 = SLSearcher(cspplanning15); print(se0.search(100000,0.5))
#p = Runtime_distribution(cspplanning15)
#p.plot_runs(1000,1000,0.7)  # warning will take a few minutes

