Skip to content

Commit

Permalink
fix authz checks from REST, do not attempt stripe creation if entitle…
Browse files Browse the repository at this point in the history
…ment manager not enabled, simplify

Signed-off-by: Sarah Funkhouser <147884153+golanglemonade@users.noreply.github.com>
  • Loading branch information
golanglemonade committed Nov 20, 2024
1 parent 49eaaee commit eb7a4b4
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 86 deletions.
133 changes: 68 additions & 65 deletions internal/ent/hooks/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,19 @@ func RegisterListeners(pool *soiree.EventPool) {
}
}

// handleCustomerCreate handles the creation of a customer in Stripe
// handleCustomerCreate handles the creation of a customer in Stripe when an OrganizationSetting is created or updated
func handleCustomerCreate(event soiree.Event) error {
client := event.Client().(*entgen.Client)

if client.EntitlementManager == nil {
log.Info().Msg("EntitlementManager not found on client, skipping customer creation")

return nil
}

props := event.Properties()
// TODO: add funcs for default unmarshalling of props and send all props to stripe
orgsettingID, exists := props["ID"]
orgSettingID, exists := props["ID"]
if !exists {
log.Info().Msg("organizationSetting ID not found in event properties")
return nil
Expand All @@ -116,54 +124,18 @@ func handleCustomerCreate(event soiree.Event) error {
if exists && stripeID != "" {
stripes := stripeID.(string)

stripecustomer, err := event.Client().(*entgen.Client).EntitlementManager.Client.Customers.Get(stripes, nil)
stripeCustomer, err := client.EntitlementManager.Client.Customers.Get(stripes, nil)
if err != nil {
log.Err(err).Msg("Failed to retrieve Stripe customer by ID")
return err
}

billingEmail, exists := props["billing_email"]
if exists && billingEmail != "" {
email := billingEmail.(string)
if stripecustomer.Email != email {
_, err := event.Client().(*entgen.Client).EntitlementManager.Client.Customers.Update(stripes, &stripe.CustomerParams{
Email: &email,
})
if err != nil {
log.Err(err).Msg("failed to update Stripe customer email")
return err
}
}
}
// check for updates to billing information and update the Stripe customer if necessary
if params, hasUpdate := checkForBillingUpdate(props, stripeCustomer); hasUpdate {
if _, err := client.EntitlementManager.Client.Customers.Update(stripes, params); err != nil {
log.Err(err).Interface("params", params).Msg("failed to update Stripe customer with new billing information")

billingPhone, exists := props["billing_phone"]
if exists && billingPhone != "" {
phone := billingPhone.(string)
if stripecustomer.Phone != phone {
_, err := event.Client().(*entgen.Client).EntitlementManager.Client.Customers.Update(stripes, &stripe.CustomerParams{
Phone: &phone,
})
if err != nil {
log.Err(err).Msg("failed to update Stripe customer phone")
return err
}
}
}

// TODO: split out address fields in ent schema
billingAddress, exists := props["billing_address"]
if exists && billingAddress != "" {
address := billingAddress.(string)
if stripecustomer.Address.Line1 != address {
_, err := event.Client().(*entgen.Client).EntitlementManager.Client.Customers.Update(stripes, &stripe.CustomerParams{
Address: &stripe.AddressParams{
Line1: &address,
},
})
if err != nil {
log.Err(err).Msg("failed to update Stripe customer address")
return err
}
return err
}
}
}
Expand All @@ -172,55 +144,48 @@ func handleCustomerCreate(event soiree.Event) error {
if exists && billingEmail != "" {
email := billingEmail.(string)

i := event.Client().(*entgen.Client).EntitlementManager.Client.Customers.List(&stripe.CustomerListParams{Email: &email})
i := client.EntitlementManager.Client.Customers.List(&stripe.CustomerListParams{Email: &email})

if !i.Next() {
customer, err := event.Client().(*entgen.Client).EntitlementManager.CreateCustomer(email)
var customerID string

if i.Next() {
customerID = i.Customer().ID
} else {
// if there is no customer with the email, create one
customer, err := client.EntitlementManager.CreateCustomer(email)
if err != nil {
log.Err(err).Msg("Failed to create Stripe customer")
return err
}

customerID = customer.ID

log.Debug().Msgf("Created Stripe customer with ID: %s", customer.ID)

if err := updateOrganizationSettingWithCustomerID(event.Context(), orgsettingID.(string), customer.ID, event.Client()); err != nil {
if err := updateOrganizationSettingWithCustomerID(event.Context(), orgSettingID.(string), customer.ID, event.Client()); err != nil {
log.Err(err).Msg("Failed to update OrganizationSetting with Stripe customer ID")
return err
}

log.Debug().Msgf("Updated OrganizationSetting with Stripe customer ID: %s", customer.ID)

subs, err := event.Client().(*entgen.Client).EntitlementManager.ListOrCreateSubscriptions(customer.ID)
if err != nil {
log.Err(err).Msg("failed to list or create Stripe subscriptions")
return err
}

checkout, err := event.Client().(*entgen.Client).EntitlementManager.CreateBillingPortalUpdateSession(subs.ID, customer.ID)
if err != nil {
log.Err(err).Msg("failed to create billing portal update session")
return err
}

log.Debug().Msgf("Created billing portal update session with URL %s", checkout.URL)
}

// TODO create ent db records / corresponding feature / plan records
subs, err := event.Client().(*entgen.Client).EntitlementManager.ListOrCreateSubscriptions(i.Customer().ID)
subs, err := client.EntitlementManager.ListOrCreateSubscriptions(customerID)
if err != nil {
log.Err(err).Msg("failed to list or create Stripe subscriptions")
return err
}

checkout, err := event.Client().(*entgen.Client).EntitlementManager.CreateBillingPortalUpdateSession(subs.ID, i.Customer().ID)
checkout, err := client.EntitlementManager.CreateBillingPortalUpdateSession(subs.ID, customerID)
if err != nil {
log.Err(err).Msg("failed to create billing portal update session")
return err
}

log.Debug().Msgf("Created billing portal update session with URL %s", checkout.URL)

if err := updateOrganizationSettingWithCustomerID(event.Context(), orgsettingID.(string), i.Customer().ID, event.Client()); err != nil {
if err := updateOrganizationSettingWithCustomerID(event.Context(), orgSettingID.(string), i.Customer().ID, event.Client()); err != nil {
log.Err(err).Msg("Failed to update OrganizationSetting with Stripe customer ID")
return err
}
Expand All @@ -241,3 +206,41 @@ func updateOrganizationSettingWithCustomerID(ctx context.Context, orgsettingID,

return nil
}

// checkForBillingUpdate checks for updates to billing information in the properties and returns a stripe.CustomerParams object with the updated information
// and a boolean indicating whether there are updates
func checkForBillingUpdate(props map[string]interface{}, stripeCustomer *stripe.Customer) (params *stripe.CustomerParams, hasUpdate bool) {
params = &stripe.CustomerParams{}

billingEmail, exists := props["billing_email"]
if exists && billingEmail != "" {
email := billingEmail.(string)
if stripeCustomer.Email != email {
params.Email = &email
hasUpdate = true
}
}

billingPhone, exists := props["billing_phone"]
if exists && billingPhone != "" {
phone := billingPhone.(string)
if stripeCustomer.Phone != phone {
params.Phone = &phone
hasUpdate = true
}
}

// TODO: split out address fields in ent schema
billingAddress, exists := props["billing_address"]
if exists && billingAddress != "" {
address := billingAddress.(string)
if stripeCustomer.Address.Line1 != address {
params.Address = &stripe.AddressParams{
Line1: &address,
}
hasUpdate = true
}
}

return
}
4 changes: 2 additions & 2 deletions internal/ent/hooks/mutationhelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/theopenlane/iam/auth"
"github.com/theopenlane/iam/fgax"

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/privacy/utils"
)

// getTuplesToAdd gets the tuples that need to be added to the authz service based on the edges that were added
Expand Down Expand Up @@ -212,7 +212,7 @@ func addTokenEditPermissions(ctx context.Context, oID string, objectType string)
log.Debug().Interface("request", req).
Msg("creating edit tuples for api token")

if _, err := generated.FromContext(ctx).Authz.WriteTupleKeys(ctx, []fgax.TupleKey{fgax.GetTupleKey(req)}, nil); err != nil {
if _, err := utils.AuthzClientFromContext(ctx).WriteTupleKeys(ctx, []fgax.TupleKey{fgax.GetTupleKey(req)}, nil); err != nil {
log.Error().Err(err).Msg("failed to create relationship tuple")

return ErrInternalServerError
Expand Down
4 changes: 2 additions & 2 deletions internal/ent/hooks/objectownedtuples.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/theopenlane/iam/auth"
"github.com/theopenlane/iam/fgax"

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/privacy/utils"
)

// HookObjectOwnedTuples is a hook that adds object owned tuples for the object being created
Expand Down Expand Up @@ -68,7 +68,7 @@ func HookObjectOwnedTuples(parents []string, skipUser bool) ent.Hook {

// write the tuples to the authz service
if len(addTuples) != 0 || len(removeTuples) != 0 {
if _, err := generated.FromContext(ctx).Authz.WriteTupleKeys(ctx, addTuples, removeTuples); err != nil {
if _, err := utils.AuthzClientFromContext(ctx).WriteTupleKeys(ctx, addTuples, removeTuples); err != nil {
return nil, err
}
}
Expand Down
5 changes: 3 additions & 2 deletions internal/ent/hooks/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/generated/hook"
"github.com/theopenlane/core/internal/ent/privacy/utils"
)

// HookTaskCreate runs on task create mutations to set default values that are not provided
Expand Down Expand Up @@ -59,7 +60,7 @@ func HookTaskAssignee() ent.Hook {
})

// get the current assignee and remove them
resp, err := generated.FromContext(ctx).Authz.ListUserRequest(ctx, fgax.ListRequest{
resp, err := utils.AuthzClientFromContext(ctx).ListUserRequest(ctx, fgax.ListRequest{
ObjectID: taskID,
ObjectType: strings.ToLower(m.Type()),
Relation: "assignee",
Expand Down Expand Up @@ -87,7 +88,7 @@ func HookTaskAssignee() ent.Hook {
}

// add the new assignee and remove the old assignee
if _, err := generated.FromContext(ctx).Authz.WriteTupleKeys(ctx, []fgax.TupleKey{addTuple}, deleteTuples); err != nil {
if _, err := utils.AuthzClientFromContext(ctx).WriteTupleKeys(ctx, []fgax.TupleKey{addTuple}, deleteTuples); err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions internal/ent/interceptors/auditlogs.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (

"github.com/theopenlane/iam/auth"

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/generated/intercept"
"github.com/theopenlane/core/internal/ent/privacy/utils"
)

// HistoryAccess is a traversal interceptor that checks if the user has the required role for the organization
Expand All @@ -36,7 +36,7 @@ func HistoryAccess(relation string, orgOwned, userOwed bool) ent.Interceptor {
for _, orgID := range au.OrganizationIDs {
req.ObjectID = orgID

allowed, err := generated.FromContext(ctx).Authz.CheckAccess(ctx, req)
allowed, err := utils.AuthzClientFromContext(ctx).CheckAccess(ctx, req)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/ent/interceptors/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
"github.com/theopenlane/iam/auth"
"github.com/theopenlane/iam/fgax"

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/generated/intercept"
"github.com/theopenlane/core/internal/ent/generated/privacy"
"github.com/theopenlane/core/internal/ent/privacy/utils"
)

// FilterListQuery filters any list query to only include the objects that the user has access to
Expand Down Expand Up @@ -54,7 +54,7 @@ func GetAuthorizedObjectIDs(ctx context.Context, objectType string) ([]string, e
ObjectType: strings.ToLower(objectType),
}

resp, err := generated.FromContext(ctx).Authz.ListObjectsRequest(ctx, req)
resp, err := utils.AuthzClientFromContext(ctx).ListObjectsRequest(ctx, req)
if err != nil {
return []string{}, nil
}
Expand Down
6 changes: 3 additions & 3 deletions internal/ent/privacy/rule/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/generated/privacy"
"github.com/theopenlane/core/pkg/middleware/transaction"
"github.com/theopenlane/core/internal/ent/privacy/utils"
)

// CheckOrgAccess checks if the authenticated user has access to the organization
Expand All @@ -34,7 +34,7 @@ func CheckOrgAccess(ctx context.Context, relation string) error {
ObjectID: au.OrganizationID,
}

access, err := transaction.FromContext(ctx).Authz.CheckOrgAccess(ctx, ac)
access, err := utils.AuthzClientFromContext(ctx).CheckOrgAccess(ctx, ac)
if err != nil {
return err
}
Expand Down Expand Up @@ -170,7 +170,7 @@ func CanCreateObjectsInOrg() privacy.MutationRuleFunc {
Relation: relation,
}

access, err := generated.FromContext(ctx).Authz.CheckOrgAccess(ctx, ac)
access, err := utils.AuthzClientFromContext(ctx).CheckOrgAccess(ctx, ac)
if err != nil {
return privacy.Skipf("unable to check access, %s", err.Error())
}
Expand Down
21 changes: 21 additions & 0 deletions internal/ent/privacy/utils/utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package utils

import (
"context"

"github.com/theopenlane/iam/fgax"

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/generated/privacy"
"github.com/theopenlane/core/pkg/middleware/transaction"
)

// NewMutationPolicyWithoutNil is creating a new slice of `privacy.MutationPolicy` by
Expand All @@ -17,3 +23,18 @@ func NewMutationPolicyWithoutNil(source privacy.MutationPolicy) privacy.Mutation

return newSlice
}

// AuthzClientFromContext returns the authz client from the context if it exists
func AuthzClientFromContext(ctx context.Context) *fgax.Client {
client := generated.FromContext(ctx)
if client != nil {
return &client.Authz
}

tx := transaction.FromContext(ctx)
if tx != nil {
return &tx.Authz
}

return nil
}
4 changes: 4 additions & 0 deletions internal/ent/schema/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
emixin "github.com/theopenlane/entx/mixin"

"github.com/theopenlane/iam/entfga"
"github.com/theopenlane/iam/fgax"

"github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/generated/privacy"
Expand Down Expand Up @@ -225,6 +226,9 @@ func (Organization) Policy() ent.Policy {
privacy.OrganizationQueryRuleFunc(func(ctx context.Context, q *generated.OrganizationQuery) error {
return q.CheckAccess(ctx)
}),
privacy.OrganizationQueryRuleFunc(func(ctx context.Context, q *generated.OrganizationQuery) error {
return rule.CheckOrgAccess(ctx, fgax.CanView)
}),
privacy.AlwaysDenyRule(), // Deny all other users
},
}
Expand Down
2 changes: 1 addition & 1 deletion internal/httpserve/handlers/checkout.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (h *Handler) CheckoutSessionHandler(ctx echo.Context) error {

settings, err := h.getOrgSettingByOrgID(reqCtx, orgID)
if err != nil {
log.Error().Err(err).Msg("unable to get organization settings by org id")
log.Error().Err(err).Str("organization_id", orgID).Msg("unable to get organization settings by org id")

return h.BadRequest(ctx, err)
}
Expand Down
4 changes: 1 addition & 3 deletions internal/httpserve/handlers/ent.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,7 @@ func (h *Handler) getOrgByID(ctx context.Context, id string) (*ent.Organization,

// getOrgSettingByOrgID returns the organization settings from an organization ID and context
func (h *Handler) getOrgSettingByOrgID(ctx context.Context, orgID string) (*ent.OrganizationSetting, error) {
orgGetCtx := privacy.DecisionContext(ctx, privacy.Allow)

org, err := h.getOrgByID(orgGetCtx, orgID)
org, err := h.getOrgByID(ctx, orgID)
if err != nil {
log.Error().Err(err).Msg("error retrieving organization")

Expand Down
Loading

0 comments on commit eb7a4b4

Please sign in to comment.