Skip to content

Commit

Permalink
Merge pull request #194 from shink/master
Browse files Browse the repository at this point in the history
feat: add support for custom function
  • Loading branch information
hsluoyz authored May 22, 2021
2 parents 1104ca2 + d16820b commit be38837
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 103 deletions.
11 changes: 11 additions & 0 deletions examples/abac_rule_custom_function_model.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[request_definition]
r = sub, obj, act

[policy_definition]
p = sub_rule, obj, act

[policy_effect]
e = some(where (p.eft == allow))

[matchers]
m = eval(p.sub_rule) && r.act == p.act
2 changes: 2 additions & 0 deletions examples/abac_rule_custom_function_policy.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
p, r.sub.name == 'alice' && custom(r.obj), r.obj, GET
p, r.sub.age >= 18 && custom(r.obj), r.obj, GET
8 changes: 4 additions & 4 deletions examples/abac_rule_with_domains_policy.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
p, r.domain.equals("domain1"), admin, domain1, data1, read
p, r.domain.equals("domain1"), admin, domain1, data1, write
p, r.domain.equals("domain2"), admin, domain2, data2, read
p, r.domain.equals("domain2"), admin, domain2, data2, write
p, r.domain == 'domain1', admin, domain1, data1, read
p, r.domain == 'domain1', admin, domain1, data1, write
p, r.domain == 'domain2', admin, domain2, data2, read
p, r.domain == 'domain2', admin, domain2, data2, write
g, alice, admin, domain1
g, bob, admin, domain2
7 changes: 1 addition & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
<dependency>
<groupId>com.googlecode.aviator</groupId>
<artifactId>aviator</artifactId>
<version>4.1.2</version>
<version>5.2.5</version>
</dependency>
<dependency>
<groupId>com.github.seancfoley</groupId>
Expand All @@ -212,11 +212,6 @@
<version>1.19</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.codehaus.janino</groupId>
<artifactId>janino</artifactId>
<version>3.1.2</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/casbin/jcasbin/main/CoreEnforcer.java
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ private boolean enforce(String matcher, Object... rvals) {
for (AviatorFunction f : functions.values()) {
aviatorEval.addFunction(f);
}
fm.setAviatorEval(aviatorEval);

modelModCount = model.getModCount();
}
Expand Down Expand Up @@ -554,6 +555,7 @@ private boolean validateEnforceSection(String section, Object... rvals) {
*/
public void resetExpressionEvaluator() {
aviatorEval = null;
fm.setAviatorEval(null);
}

public boolean isAutoNotifyWatcher() {
Expand Down
13 changes: 8 additions & 5 deletions src/main/java/org/casbin/jcasbin/main/ManagementEnforcer.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

package org.casbin.jcasbin.main;

import com.googlecode.aviator.runtime.type.AviatorFunction;
import org.casbin.jcasbin.effect.Effect;
import org.casbin.jcasbin.model.Assertion;
import org.casbin.jcasbin.util.Util;
import org.casbin.jcasbin.util.function.CustomFunction;

import java.lang.reflect.Method;
import java.util.*;

/**
Expand Down Expand Up @@ -475,6 +474,7 @@ public boolean addNamedGroupingPolicy(String ptype, List<String> params) {
boolean ruleAdded = addPolicy("g", ptype, params);

aviatorEval = null;
fm.setAviatorEval(null);
return ruleAdded;
}

Expand Down Expand Up @@ -534,6 +534,7 @@ public boolean removeNamedGroupingPolicy(String ptype, List<String> params) {
boolean ruleRemoved = removePolicy("g", ptype, params);

aviatorEval = null;
fm.setAviatorEval(null);
return ruleRemoved;
}

Expand Down Expand Up @@ -561,18 +562,20 @@ public boolean removeFilteredNamedGroupingPolicy(String ptype, int fieldIndex, S
boolean ruleRemoved = removeFilteredPolicy("g", ptype, fieldIndex, fieldValues);

aviatorEval = null;
fm.setAviatorEval(null);
return ruleRemoved;
}

/**
* addFunction adds a customized function.
*
* @param name the name of the new function.
* @param function the function.
* @param name the name of the function.
* @param function the custom function.
*/
public void addFunction(String name, AviatorFunction function) {
public void addFunction(String name, CustomFunction function) {
fm.addFunction(name, function);
aviatorEval = null;
fm.setAviatorEval(null);
}

/**
Expand Down
26 changes: 26 additions & 0 deletions src/main/java/org/casbin/jcasbin/model/FunctionMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package org.casbin.jcasbin.model;

import com.googlecode.aviator.AviatorEvaluatorInstance;
import com.googlecode.aviator.runtime.type.AviatorFunction;
import org.casbin.jcasbin.util.function.*;

Expand All @@ -39,6 +40,31 @@ public void addFunction(String name, AviatorFunction function) {
fm.put(name, function);
}

/**
* setAviatorEval adds AviatorEvaluatorInstance to the custom function.
*
* @param name the name of the custom function.
* @param aviatorEval the AviatorEvaluatorInstance object.
*/
public void setAviatorEval(String name, AviatorEvaluatorInstance aviatorEval) {
if (fm.containsKey(name) && fm.get(name) instanceof CustomFunction) {
((CustomFunction) fm.get(name)).setAviatorEval(aviatorEval);
}
}

/**
* setAviatorEval adds AviatorEvaluatorInstance to all the custom function.
*
* @param aviatorEval the AviatorEvaluatorInstance object.
*/
public void setAviatorEval(AviatorEvaluatorInstance aviatorEval) {
for (AviatorFunction function : fm.values()) {
if (function instanceof CustomFunction) {
((CustomFunction) function).setAviatorEval(aviatorEval);
}
}
}

/**
* loadFunctionMap loads an initial function map.
*
Expand Down
92 changes: 14 additions & 78 deletions src/main/java/org/casbin/jcasbin/util/BuiltInFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package org.casbin.jcasbin.util;

import com.googlecode.aviator.AviatorEvaluator;
import com.googlecode.aviator.AviatorEvaluatorInstance;
import com.googlecode.aviator.runtime.function.AbstractFunction;
import com.googlecode.aviator.runtime.function.FunctionUtils;
import com.googlecode.aviator.runtime.type.AviatorBoolean;
Expand All @@ -23,20 +25,15 @@
import inet.ipaddr.IPAddress;
import inet.ipaddr.IPAddressString;
import org.casbin.jcasbin.rbac.RoleManager;
import org.codehaus.commons.compiler.CompileException;
import org.codehaus.janino.ExpressionEvaluator;

import java.lang.reflect.InvocationTargetException;
import java.util.*;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class BuiltInFunctions {

private static Pattern keyMatch2Pattern = Pattern.compile("(.*):[^/]+(.*)");
private static Pattern keyMatch3Pattern = Pattern.compile("(.*)\\{[^/]+}(.*)");
private static Pattern evalPattern = Pattern.compile("(?<=\\.).*?(?=\\.| )");

/**
* keyMatch determines whether key1 matches the pattern of key2 (similar to RESTful path), key2
Expand Down Expand Up @@ -363,81 +360,20 @@ public String getName() {
}

/**
* eval calculates the stringified boolean expression and return its result. The syntax of
* expressions is exactly the same as Java. Flaw: dynamically generated classes or non-static
* inner class cannot be used.
* eval calculates the stringified boolean expression and return its result.
*
* @author tldyl
* @since 2020-07-02
*
* @param eval Boolean expression.
* @param env Parameters.
* @return The result of the eval.
* @param eval the stringified boolean expression.
* @param env the key-value pair of the parameters in the expression.
* @param aviatorEval the AviatorEvaluatorInstance object which contains built-in functions and custom functions.
* @return the result of the eval.
*/
public static boolean eval(String eval, Map<String, Object> env) {
ExpressionEvaluator evaluator = new ExpressionEvaluator();
Map<String, Map<String, Object>> evalModels = getEvalModels(env);
try {
List<String> parameterNameList = new ArrayList<>();
List<Object> parameterValueList = new ArrayList<>();
List<Class<?>> parameterClassList = new ArrayList<>();
for (Map.Entry<String, Object> entry: env.entrySet()) {
parameterNameList.add(entry.getKey());
parameterValueList.add(entry.getValue());
parameterClassList.add(entry.getValue().getClass());
}
List<String> sortedSrc = new ArrayList<>(getReplaceTargets(evalModels));
sortedSrc.sort((o1, o2) -> o1.length() > o2.length() ? -1 : 1);
for (String s : sortedSrc) {
eval = eval.replace("." + s, "_" + s);
}
Matcher matcher = evalPattern.matcher(eval);
while (matcher.find()) {
for (int i = 0; i <= matcher.groupCount(); i++) {
eval = eval.replace(matcher.group(), obtainFieldGetMethodName(matcher.group()));
}
}
evaluator.setParameters(parameterNameList.toArray(new String[0]), parameterClassList.toArray(new Class[0]));
evaluator.cook(eval);
return (boolean) evaluator.evaluate(parameterValueList.toArray(new Object[0]));
} catch (InvocationTargetException | CompileException e) {
e.printStackTrace();
public static boolean eval(String eval, Map<String, Object> env, AviatorEvaluatorInstance aviatorEval) {
boolean res;
if (aviatorEval != null) {
res = (boolean) aviatorEval.execute(eval, env);
} else {
res = (boolean) AviatorEvaluator.execute(eval, env);
}
return false;
}

/**
* getEvalModels extracts the value from env and assemble it into a EvalModel object.
*
* @param env the map.
*/
private static Map<String, Map<String, Object>> getEvalModels(Map<String, Object> env) {
final Map<String, Map<String, Object>> evalModels = new HashMap<>();
for (final Entry<String, Object> entry : env.entrySet()) {
final String[] names = entry.getKey().split("_");
evalModels.computeIfAbsent(names[0], k -> new HashMap<>()).put(names[1], entry.getValue());
}
return evalModels;
}

private static Set<String> getReplaceTargets(Map<String, Map<String, Object>> evalModels) {
Set<String> ret = new HashSet<>();
for (final Entry<String, Map<String, Object>> entry : evalModels.entrySet()) {
ret.addAll(entry.getValue().keySet());
}
return ret;
}

/**
* Get the function name of its get method according to the field name.
* For example, the input parameter is "age", the output parameter is "getAge()"
*
* @param fieldName the file name.
* @return the function name of its get method.
*/
private static String obtainFieldGetMethodName(String fieldName) {
return new StringBuffer().append("get")
.append(fieldName.substring(0, 1).toUpperCase())
.append(fieldName.substring(1)).append("()").toString();
return res;
}
}
47 changes: 47 additions & 0 deletions src/main/java/org/casbin/jcasbin/util/function/CustomFunction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2020 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.casbin.jcasbin.util.function;

import com.googlecode.aviator.AviatorEvaluatorInstance;
import com.googlecode.aviator.runtime.function.AbstractFunction;

import java.util.Map;

/**
* @author: shink
*/
public abstract class CustomFunction extends AbstractFunction {

private AviatorEvaluatorInstance aviatorEval;

public String replaceTargets(String exp, Map<String, Object> env) {
for (String key : env.keySet()) {
int index;
if ((index = key.indexOf('_')) != -1) {
String s = key.substring(index + 1);
exp = exp.replace("." + s, "_" + s);
}
}
return exp;
}

public AviatorEvaluatorInstance getAviatorEval() {
return aviatorEval;
}

public void setAviatorEval(AviatorEvaluatorInstance aviatorEval) {
this.aviatorEval = aviatorEval;
}
}
14 changes: 8 additions & 6 deletions src/main/java/org/casbin/jcasbin/util/function/EvalFunc.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

package org.casbin.jcasbin.util.function;

import com.googlecode.aviator.runtime.function.AbstractFunction;
import com.googlecode.aviator.runtime.function.FunctionUtils;
import com.googlecode.aviator.runtime.type.AviatorBoolean;
import com.googlecode.aviator.runtime.type.AviatorObject;
Expand All @@ -24,14 +23,17 @@

/**
* EvalFunc is the wrapper for eval.
* @author tldyl
* @since 2020-07-02
* It extends CustomFunction, so it can be used in matcher and policy rule.
*
* @author shink
*/
public class EvalFunc extends AbstractFunction {
public class EvalFunc extends CustomFunction {

@Override
public AviatorObject call(Map<String, Object> env, AviatorObject arg1) {
String ev = FunctionUtils.getStringValue(arg1, env);
return AviatorBoolean.valueOf(BuiltInFunctions.eval(ev, env));
String eval = FunctionUtils.getStringValue(arg1, env);
eval = replaceTargets(eval, env);
return AviatorBoolean.valueOf(BuiltInFunctions.eval(eval, env, getAviatorEval()));
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/org/casbin/jcasbin/main/AbacAPIUnitTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void testEvalWithDomain() {
testDomainEnforce(e, "bob", "domain2", "data2", "read", true);
}

public static class TestEvalRule { //This class must be static.
public static class TestEvalRule {
private String name;
private int age;

Expand Down
Loading

0 comments on commit be38837

Please sign in to comment.