1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11 Clock, Page, Pagination,
12 upstream_oauth2::{
13 UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14 },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26 DatabaseError, DatabaseInconsistencyError,
27 filter::{Filter, StatementExt},
28 iden::UpstreamOAuthProviders,
29 pagination::QueryBuilderExt,
30 tracing::ExecuteExt,
31};
32
/// An implementation of [`UpstreamOAuthProviderRepository`] backed by a
/// PostgreSQL connection.
pub struct PgUpstreamOAuthProviderRepository<'c> {
    /// Mutably borrowed connection on which every query issued by this
    /// repository is executed.
    conn: &'c mut PgConnection,
}
38
impl<'c> PgUpstreamOAuthProviderRepository<'c> {
    /// Create a new [`PgUpstreamOAuthProviderRepository`] from a mutable
    /// borrow of a PostgreSQL connection.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
46
/// Raw row of the `upstream_oauth_providers` table, as fetched by the queries
/// in this module.
///
/// Typed values (scope, algorithms, URLs, modes…) are kept in their textual
/// form here and parsed by the `TryFrom<ProviderLookup>` implementation for
/// [`UpstreamOAuthProvider`].
///
/// `#[enum_def]` generates the companion `ProviderLookupIden` enum, which the
/// `list` query uses to alias the selected columns to these field names.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `NULL` means the provider is enabled (see the `disabled_at IS NULL`
    // conditions used by the filter and by `all_enabled`)
    disabled_at: Option<DateTime<Utc>>,
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // `NULL` is converted to an empty list of parameters
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
}
75
76impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
77 type Error = DatabaseInconsistencyError;
78
79 #[allow(clippy::too_many_lines)]
80 fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
81 let id = value.upstream_oauth_provider_id.into();
82 let scope = value.scope.parse().map_err(|e| {
83 DatabaseInconsistencyError::on("upstream_oauth_providers")
84 .column("scope")
85 .row(id)
86 .source(e)
87 })?;
88 let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
89 DatabaseInconsistencyError::on("upstream_oauth_providers")
90 .column("token_endpoint_auth_method")
91 .row(id)
92 .source(e)
93 })?;
94 let token_endpoint_signing_alg = value
95 .token_endpoint_signing_alg
96 .map(|x| x.parse())
97 .transpose()
98 .map_err(|e| {
99 DatabaseInconsistencyError::on("upstream_oauth_providers")
100 .column("token_endpoint_signing_alg")
101 .row(id)
102 .source(e)
103 })?;
104 let id_token_signed_response_alg =
105 value.id_token_signed_response_alg.parse().map_err(|e| {
106 DatabaseInconsistencyError::on("upstream_oauth_providers")
107 .column("id_token_signed_response_alg")
108 .row(id)
109 .source(e)
110 })?;
111
112 let userinfo_signed_response_alg = value
113 .userinfo_signed_response_alg
114 .map(|x| x.parse())
115 .transpose()
116 .map_err(|e| {
117 DatabaseInconsistencyError::on("upstream_oauth_providers")
118 .column("userinfo_signed_response_alg")
119 .row(id)
120 .source(e)
121 })?;
122
123 let authorization_endpoint_override = value
124 .authorization_endpoint_override
125 .map(|x| x.parse())
126 .transpose()
127 .map_err(|e| {
128 DatabaseInconsistencyError::on("upstream_oauth_providers")
129 .column("authorization_endpoint_override")
130 .row(id)
131 .source(e)
132 })?;
133
134 let token_endpoint_override = value
135 .token_endpoint_override
136 .map(|x| x.parse())
137 .transpose()
138 .map_err(|e| {
139 DatabaseInconsistencyError::on("upstream_oauth_providers")
140 .column("token_endpoint_override")
141 .row(id)
142 .source(e)
143 })?;
144
145 let userinfo_endpoint_override = value
146 .userinfo_endpoint_override
147 .map(|x| x.parse())
148 .transpose()
149 .map_err(|e| {
150 DatabaseInconsistencyError::on("upstream_oauth_providers")
151 .column("userinfo_endpoint_override")
152 .row(id)
153 .source(e)
154 })?;
155
156 let jwks_uri_override = value
157 .jwks_uri_override
158 .map(|x| x.parse())
159 .transpose()
160 .map_err(|e| {
161 DatabaseInconsistencyError::on("upstream_oauth_providers")
162 .column("jwks_uri_override")
163 .row(id)
164 .source(e)
165 })?;
166
167 let discovery_mode = value.discovery_mode.parse().map_err(|e| {
168 DatabaseInconsistencyError::on("upstream_oauth_providers")
169 .column("discovery_mode")
170 .row(id)
171 .source(e)
172 })?;
173
174 let pkce_mode = value.pkce_mode.parse().map_err(|e| {
175 DatabaseInconsistencyError::on("upstream_oauth_providers")
176 .column("pkce_mode")
177 .row(id)
178 .source(e)
179 })?;
180
181 let response_mode = value
182 .response_mode
183 .map(|x| x.parse())
184 .transpose()
185 .map_err(|e| {
186 DatabaseInconsistencyError::on("upstream_oauth_providers")
187 .column("response_mode")
188 .row(id)
189 .source(e)
190 })?;
191
192 let additional_authorization_parameters = value
193 .additional_parameters
194 .map(|Json(x)| x)
195 .unwrap_or_default();
196
197 Ok(UpstreamOAuthProvider {
198 id,
199 issuer: value.issuer,
200 human_name: value.human_name,
201 brand_name: value.brand_name,
202 scope,
203 client_id: value.client_id,
204 encrypted_client_secret: value.encrypted_client_secret,
205 token_endpoint_auth_method,
206 token_endpoint_signing_alg,
207 id_token_signed_response_alg,
208 fetch_userinfo: value.fetch_userinfo,
209 userinfo_signed_response_alg,
210 created_at: value.created_at,
211 disabled_at: value.disabled_at,
212 claims_imports: value.claims_imports.0,
213 authorization_endpoint_override,
214 token_endpoint_override,
215 userinfo_endpoint_override,
216 jwks_uri_override,
217 discovery_mode,
218 pkce_mode,
219 response_mode,
220 additional_authorization_parameters,
221 forward_login_hint: value.forward_login_hint,
222 })
223 }
224}
225
226impl Filter for UpstreamOAuthProviderFilter<'_> {
227 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
228 sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
229 Expr::col((
230 UpstreamOAuthProviders::Table,
231 UpstreamOAuthProviders::DisabledAt,
232 ))
233 .is_null()
234 .eq(enabled)
235 }))
236 }
237}
238
239#[async_trait]
240impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
241 type Error = DatabaseError;
242
243 #[tracing::instrument(
244 name = "db.upstream_oauth_provider.lookup",
245 skip_all,
246 fields(
247 db.query.text,
248 upstream_oauth_provider.id = %id,
249 ),
250 err,
251 )]
252 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
253 let res = sqlx::query_as!(
254 ProviderLookup,
255 r#"
256 SELECT
257 upstream_oauth_provider_id,
258 issuer,
259 human_name,
260 brand_name,
261 scope,
262 client_id,
263 encrypted_client_secret,
264 token_endpoint_signing_alg,
265 token_endpoint_auth_method,
266 id_token_signed_response_alg,
267 fetch_userinfo,
268 userinfo_signed_response_alg,
269 created_at,
270 disabled_at,
271 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
272 jwks_uri_override,
273 authorization_endpoint_override,
274 token_endpoint_override,
275 userinfo_endpoint_override,
276 discovery_mode,
277 pkce_mode,
278 response_mode,
279 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
280 forward_login_hint
281 FROM upstream_oauth_providers
282 WHERE upstream_oauth_provider_id = $1
283 "#,
284 Uuid::from(id),
285 )
286 .traced()
287 .fetch_optional(&mut *self.conn)
288 .await?;
289
290 let res = res
291 .map(UpstreamOAuthProvider::try_from)
292 .transpose()
293 .map_err(DatabaseError::from)?;
294
295 Ok(res)
296 }
297
298 #[tracing::instrument(
299 name = "db.upstream_oauth_provider.add",
300 skip_all,
301 fields(
302 db.query.text,
303 upstream_oauth_provider.id,
304 upstream_oauth_provider.issuer = params.issuer,
305 upstream_oauth_provider.client_id = %params.client_id,
306 ),
307 err,
308 )]
309 async fn add(
310 &mut self,
311 rng: &mut (dyn RngCore + Send),
312 clock: &dyn Clock,
313 params: UpstreamOAuthProviderParams,
314 ) -> Result<UpstreamOAuthProvider, Self::Error> {
315 let created_at = clock.now();
316 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
317 tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
318
319 sqlx::query!(
320 r#"
321 INSERT INTO upstream_oauth_providers (
322 upstream_oauth_provider_id,
323 issuer,
324 human_name,
325 brand_name,
326 scope,
327 token_endpoint_auth_method,
328 token_endpoint_signing_alg,
329 id_token_signed_response_alg,
330 fetch_userinfo,
331 userinfo_signed_response_alg,
332 client_id,
333 encrypted_client_secret,
334 claims_imports,
335 authorization_endpoint_override,
336 token_endpoint_override,
337 userinfo_endpoint_override,
338 jwks_uri_override,
339 discovery_mode,
340 pkce_mode,
341 response_mode,
342 forward_login_hint,
343 created_at
344 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
345 $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
346 "#,
347 Uuid::from(id),
348 params.issuer.as_deref(),
349 params.human_name.as_deref(),
350 params.brand_name.as_deref(),
351 params.scope.to_string(),
352 params.token_endpoint_auth_method.to_string(),
353 params
354 .token_endpoint_signing_alg
355 .as_ref()
356 .map(ToString::to_string),
357 params.id_token_signed_response_alg.to_string(),
358 params.fetch_userinfo,
359 params
360 .userinfo_signed_response_alg
361 .as_ref()
362 .map(ToString::to_string),
363 ¶ms.client_id,
364 params.encrypted_client_secret.as_deref(),
365 Json(¶ms.claims_imports) as _,
366 params
367 .authorization_endpoint_override
368 .as_ref()
369 .map(ToString::to_string),
370 params
371 .token_endpoint_override
372 .as_ref()
373 .map(ToString::to_string),
374 params
375 .userinfo_endpoint_override
376 .as_ref()
377 .map(ToString::to_string),
378 params.jwks_uri_override.as_ref().map(ToString::to_string),
379 params.discovery_mode.as_str(),
380 params.pkce_mode.as_str(),
381 params.response_mode.as_ref().map(ToString::to_string),
382 params.forward_login_hint,
383 created_at,
384 )
385 .traced()
386 .execute(&mut *self.conn)
387 .await?;
388
389 Ok(UpstreamOAuthProvider {
390 id,
391 issuer: params.issuer,
392 human_name: params.human_name,
393 brand_name: params.brand_name,
394 scope: params.scope,
395 client_id: params.client_id,
396 encrypted_client_secret: params.encrypted_client_secret,
397 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
398 token_endpoint_auth_method: params.token_endpoint_auth_method,
399 id_token_signed_response_alg: params.id_token_signed_response_alg,
400 fetch_userinfo: params.fetch_userinfo,
401 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
402 created_at,
403 disabled_at: None,
404 claims_imports: params.claims_imports,
405 authorization_endpoint_override: params.authorization_endpoint_override,
406 token_endpoint_override: params.token_endpoint_override,
407 userinfo_endpoint_override: params.userinfo_endpoint_override,
408 jwks_uri_override: params.jwks_uri_override,
409 discovery_mode: params.discovery_mode,
410 pkce_mode: params.pkce_mode,
411 response_mode: params.response_mode,
412 additional_authorization_parameters: params.additional_authorization_parameters,
413 forward_login_hint: params.forward_login_hint,
414 })
415 }
416
417 #[tracing::instrument(
418 name = "db.upstream_oauth_provider.delete_by_id",
419 skip_all,
420 fields(
421 db.query.text,
422 upstream_oauth_provider.id = %id,
423 ),
424 err,
425 )]
426 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
427 {
430 let span = info_span!(
431 "db.oauth2_client.delete_by_id.authorization_sessions",
432 upstream_oauth_provider.id = %id,
433 { DB_QUERY_TEXT } = tracing::field::Empty,
434 );
435 sqlx::query!(
436 r#"
437 DELETE FROM upstream_oauth_authorization_sessions
438 WHERE upstream_oauth_provider_id = $1
439 "#,
440 Uuid::from(id),
441 )
442 .record(&span)
443 .execute(&mut *self.conn)
444 .instrument(span)
445 .await?;
446 }
447
448 {
451 let span = info_span!(
452 "db.oauth2_client.delete_by_id.links",
453 upstream_oauth_provider.id = %id,
454 { DB_QUERY_TEXT } = tracing::field::Empty,
455 );
456 sqlx::query!(
457 r#"
458 DELETE FROM upstream_oauth_links
459 WHERE upstream_oauth_provider_id = $1
460 "#,
461 Uuid::from(id),
462 )
463 .record(&span)
464 .execute(&mut *self.conn)
465 .instrument(span)
466 .await?;
467 }
468
469 let res = sqlx::query!(
470 r#"
471 DELETE FROM upstream_oauth_providers
472 WHERE upstream_oauth_provider_id = $1
473 "#,
474 Uuid::from(id),
475 )
476 .traced()
477 .execute(&mut *self.conn)
478 .await?;
479
480 DatabaseError::ensure_affected_rows(&res, 1)
481 }
482
483 #[tracing::instrument(
484 name = "db.upstream_oauth_provider.add",
485 skip_all,
486 fields(
487 db.query.text,
488 upstream_oauth_provider.id = %id,
489 upstream_oauth_provider.issuer = params.issuer,
490 upstream_oauth_provider.client_id = %params.client_id,
491 ),
492 err,
493 )]
494 async fn upsert(
495 &mut self,
496 clock: &dyn Clock,
497 id: Ulid,
498 params: UpstreamOAuthProviderParams,
499 ) -> Result<UpstreamOAuthProvider, Self::Error> {
500 let created_at = clock.now();
501
502 let created_at = sqlx::query_scalar!(
503 r#"
504 INSERT INTO upstream_oauth_providers (
505 upstream_oauth_provider_id,
506 issuer,
507 human_name,
508 brand_name,
509 scope,
510 token_endpoint_auth_method,
511 token_endpoint_signing_alg,
512 id_token_signed_response_alg,
513 fetch_userinfo,
514 userinfo_signed_response_alg,
515 client_id,
516 encrypted_client_secret,
517 claims_imports,
518 authorization_endpoint_override,
519 token_endpoint_override,
520 userinfo_endpoint_override,
521 jwks_uri_override,
522 discovery_mode,
523 pkce_mode,
524 response_mode,
525 additional_parameters,
526 forward_login_hint,
527 ui_order,
528 created_at
529 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
530 $12, $13, $14, $15, $16, $17, $18, $19, $20,
531 $21, $22, $23, $24)
532 ON CONFLICT (upstream_oauth_provider_id)
533 DO UPDATE
534 SET
535 issuer = EXCLUDED.issuer,
536 human_name = EXCLUDED.human_name,
537 brand_name = EXCLUDED.brand_name,
538 scope = EXCLUDED.scope,
539 token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
540 token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
541 id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
542 fetch_userinfo = EXCLUDED.fetch_userinfo,
543 userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
544 disabled_at = NULL,
545 client_id = EXCLUDED.client_id,
546 encrypted_client_secret = EXCLUDED.encrypted_client_secret,
547 claims_imports = EXCLUDED.claims_imports,
548 authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
549 token_endpoint_override = EXCLUDED.token_endpoint_override,
550 userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
551 jwks_uri_override = EXCLUDED.jwks_uri_override,
552 discovery_mode = EXCLUDED.discovery_mode,
553 pkce_mode = EXCLUDED.pkce_mode,
554 response_mode = EXCLUDED.response_mode,
555 additional_parameters = EXCLUDED.additional_parameters,
556 forward_login_hint = EXCLUDED.forward_login_hint,
557 ui_order = EXCLUDED.ui_order
558 RETURNING created_at
559 "#,
560 Uuid::from(id),
561 params.issuer.as_deref(),
562 params.human_name.as_deref(),
563 params.brand_name.as_deref(),
564 params.scope.to_string(),
565 params.token_endpoint_auth_method.to_string(),
566 params
567 .token_endpoint_signing_alg
568 .as_ref()
569 .map(ToString::to_string),
570 params.id_token_signed_response_alg.to_string(),
571 params.fetch_userinfo,
572 params
573 .userinfo_signed_response_alg
574 .as_ref()
575 .map(ToString::to_string),
576 ¶ms.client_id,
577 params.encrypted_client_secret.as_deref(),
578 Json(¶ms.claims_imports) as _,
579 params
580 .authorization_endpoint_override
581 .as_ref()
582 .map(ToString::to_string),
583 params
584 .token_endpoint_override
585 .as_ref()
586 .map(ToString::to_string),
587 params
588 .userinfo_endpoint_override
589 .as_ref()
590 .map(ToString::to_string),
591 params.jwks_uri_override.as_ref().map(ToString::to_string),
592 params.discovery_mode.as_str(),
593 params.pkce_mode.as_str(),
594 params.response_mode.as_ref().map(ToString::to_string),
595 Json(¶ms.additional_authorization_parameters) as _,
596 params.forward_login_hint,
597 params.ui_order,
598 created_at,
599 )
600 .traced()
601 .fetch_one(&mut *self.conn)
602 .await?;
603
604 Ok(UpstreamOAuthProvider {
605 id,
606 issuer: params.issuer,
607 human_name: params.human_name,
608 brand_name: params.brand_name,
609 scope: params.scope,
610 client_id: params.client_id,
611 encrypted_client_secret: params.encrypted_client_secret,
612 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
613 token_endpoint_auth_method: params.token_endpoint_auth_method,
614 id_token_signed_response_alg: params.id_token_signed_response_alg,
615 fetch_userinfo: params.fetch_userinfo,
616 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
617 created_at,
618 disabled_at: None,
619 claims_imports: params.claims_imports,
620 authorization_endpoint_override: params.authorization_endpoint_override,
621 token_endpoint_override: params.token_endpoint_override,
622 userinfo_endpoint_override: params.userinfo_endpoint_override,
623 jwks_uri_override: params.jwks_uri_override,
624 discovery_mode: params.discovery_mode,
625 pkce_mode: params.pkce_mode,
626 response_mode: params.response_mode,
627 additional_authorization_parameters: params.additional_authorization_parameters,
628 forward_login_hint: params.forward_login_hint,
629 })
630 }
631
632 #[tracing::instrument(
633 name = "db.upstream_oauth_provider.disable",
634 skip_all,
635 fields(
636 db.query.text,
637 %upstream_oauth_provider.id,
638 ),
639 err,
640 )]
641 async fn disable(
642 &mut self,
643 clock: &dyn Clock,
644 mut upstream_oauth_provider: UpstreamOAuthProvider,
645 ) -> Result<UpstreamOAuthProvider, Self::Error> {
646 let disabled_at = clock.now();
647 let res = sqlx::query!(
648 r#"
649 UPDATE upstream_oauth_providers
650 SET disabled_at = $2
651 WHERE upstream_oauth_provider_id = $1
652 "#,
653 Uuid::from(upstream_oauth_provider.id),
654 disabled_at,
655 )
656 .traced()
657 .execute(&mut *self.conn)
658 .await?;
659
660 DatabaseError::ensure_affected_rows(&res, 1)?;
661
662 upstream_oauth_provider.disabled_at = Some(disabled_at);
663
664 Ok(upstream_oauth_provider)
665 }
666
667 #[tracing::instrument(
668 name = "db.upstream_oauth_provider.list",
669 skip_all,
670 fields(
671 db.query.text,
672 ),
673 err,
674 )]
675 async fn list(
676 &mut self,
677 filter: UpstreamOAuthProviderFilter<'_>,
678 pagination: Pagination,
679 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
680 let (sql, arguments) = Query::select()
681 .expr_as(
682 Expr::col((
683 UpstreamOAuthProviders::Table,
684 UpstreamOAuthProviders::UpstreamOAuthProviderId,
685 )),
686 ProviderLookupIden::UpstreamOauthProviderId,
687 )
688 .expr_as(
689 Expr::col((
690 UpstreamOAuthProviders::Table,
691 UpstreamOAuthProviders::Issuer,
692 )),
693 ProviderLookupIden::Issuer,
694 )
695 .expr_as(
696 Expr::col((
697 UpstreamOAuthProviders::Table,
698 UpstreamOAuthProviders::HumanName,
699 )),
700 ProviderLookupIden::HumanName,
701 )
702 .expr_as(
703 Expr::col((
704 UpstreamOAuthProviders::Table,
705 UpstreamOAuthProviders::BrandName,
706 )),
707 ProviderLookupIden::BrandName,
708 )
709 .expr_as(
710 Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
711 ProviderLookupIden::Scope,
712 )
713 .expr_as(
714 Expr::col((
715 UpstreamOAuthProviders::Table,
716 UpstreamOAuthProviders::ClientId,
717 )),
718 ProviderLookupIden::ClientId,
719 )
720 .expr_as(
721 Expr::col((
722 UpstreamOAuthProviders::Table,
723 UpstreamOAuthProviders::EncryptedClientSecret,
724 )),
725 ProviderLookupIden::EncryptedClientSecret,
726 )
727 .expr_as(
728 Expr::col((
729 UpstreamOAuthProviders::Table,
730 UpstreamOAuthProviders::TokenEndpointSigningAlg,
731 )),
732 ProviderLookupIden::TokenEndpointSigningAlg,
733 )
734 .expr_as(
735 Expr::col((
736 UpstreamOAuthProviders::Table,
737 UpstreamOAuthProviders::TokenEndpointAuthMethod,
738 )),
739 ProviderLookupIden::TokenEndpointAuthMethod,
740 )
741 .expr_as(
742 Expr::col((
743 UpstreamOAuthProviders::Table,
744 UpstreamOAuthProviders::IdTokenSignedResponseAlg,
745 )),
746 ProviderLookupIden::IdTokenSignedResponseAlg,
747 )
748 .expr_as(
749 Expr::col((
750 UpstreamOAuthProviders::Table,
751 UpstreamOAuthProviders::FetchUserinfo,
752 )),
753 ProviderLookupIden::FetchUserinfo,
754 )
755 .expr_as(
756 Expr::col((
757 UpstreamOAuthProviders::Table,
758 UpstreamOAuthProviders::UserinfoSignedResponseAlg,
759 )),
760 ProviderLookupIden::UserinfoSignedResponseAlg,
761 )
762 .expr_as(
763 Expr::col((
764 UpstreamOAuthProviders::Table,
765 UpstreamOAuthProviders::CreatedAt,
766 )),
767 ProviderLookupIden::CreatedAt,
768 )
769 .expr_as(
770 Expr::col((
771 UpstreamOAuthProviders::Table,
772 UpstreamOAuthProviders::DisabledAt,
773 )),
774 ProviderLookupIden::DisabledAt,
775 )
776 .expr_as(
777 Expr::col((
778 UpstreamOAuthProviders::Table,
779 UpstreamOAuthProviders::ClaimsImports,
780 )),
781 ProviderLookupIden::ClaimsImports,
782 )
783 .expr_as(
784 Expr::col((
785 UpstreamOAuthProviders::Table,
786 UpstreamOAuthProviders::JwksUriOverride,
787 )),
788 ProviderLookupIden::JwksUriOverride,
789 )
790 .expr_as(
791 Expr::col((
792 UpstreamOAuthProviders::Table,
793 UpstreamOAuthProviders::TokenEndpointOverride,
794 )),
795 ProviderLookupIden::TokenEndpointOverride,
796 )
797 .expr_as(
798 Expr::col((
799 UpstreamOAuthProviders::Table,
800 UpstreamOAuthProviders::AuthorizationEndpointOverride,
801 )),
802 ProviderLookupIden::AuthorizationEndpointOverride,
803 )
804 .expr_as(
805 Expr::col((
806 UpstreamOAuthProviders::Table,
807 UpstreamOAuthProviders::UserinfoEndpointOverride,
808 )),
809 ProviderLookupIden::UserinfoEndpointOverride,
810 )
811 .expr_as(
812 Expr::col((
813 UpstreamOAuthProviders::Table,
814 UpstreamOAuthProviders::DiscoveryMode,
815 )),
816 ProviderLookupIden::DiscoveryMode,
817 )
818 .expr_as(
819 Expr::col((
820 UpstreamOAuthProviders::Table,
821 UpstreamOAuthProviders::PkceMode,
822 )),
823 ProviderLookupIden::PkceMode,
824 )
825 .expr_as(
826 Expr::col((
827 UpstreamOAuthProviders::Table,
828 UpstreamOAuthProviders::ResponseMode,
829 )),
830 ProviderLookupIden::ResponseMode,
831 )
832 .expr_as(
833 Expr::col((
834 UpstreamOAuthProviders::Table,
835 UpstreamOAuthProviders::AdditionalParameters,
836 )),
837 ProviderLookupIden::AdditionalParameters,
838 )
839 .expr_as(
840 Expr::col((
841 UpstreamOAuthProviders::Table,
842 UpstreamOAuthProviders::ForwardLoginHint,
843 )),
844 ProviderLookupIden::ForwardLoginHint,
845 )
846 .from(UpstreamOAuthProviders::Table)
847 .apply_filter(filter)
848 .generate_pagination(
849 (
850 UpstreamOAuthProviders::Table,
851 UpstreamOAuthProviders::UpstreamOAuthProviderId,
852 ),
853 pagination,
854 )
855 .build_sqlx(PostgresQueryBuilder);
856
857 let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
858 .traced()
859 .fetch_all(&mut *self.conn)
860 .await?;
861
862 let page = pagination
863 .process(edges)
864 .try_map(UpstreamOAuthProvider::try_from)?;
865
866 return Ok(page);
867 }
868
869 #[tracing::instrument(
870 name = "db.upstream_oauth_provider.count",
871 skip_all,
872 fields(
873 db.query.text,
874 ),
875 err,
876 )]
877 async fn count(
878 &mut self,
879 filter: UpstreamOAuthProviderFilter<'_>,
880 ) -> Result<usize, Self::Error> {
881 let (sql, arguments) = Query::select()
882 .expr(
883 Expr::col((
884 UpstreamOAuthProviders::Table,
885 UpstreamOAuthProviders::UpstreamOAuthProviderId,
886 ))
887 .count(),
888 )
889 .from(UpstreamOAuthProviders::Table)
890 .apply_filter(filter)
891 .build_sqlx(PostgresQueryBuilder);
892
893 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
894 .traced()
895 .fetch_one(&mut *self.conn)
896 .await?;
897
898 count
899 .try_into()
900 .map_err(DatabaseError::to_invalid_operation)
901 }
902
903 #[tracing::instrument(
904 name = "db.upstream_oauth_provider.all_enabled",
905 skip_all,
906 fields(
907 db.query.text,
908 ),
909 err,
910 )]
911 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
912 let res = sqlx::query_as!(
913 ProviderLookup,
914 r#"
915 SELECT
916 upstream_oauth_provider_id,
917 issuer,
918 human_name,
919 brand_name,
920 scope,
921 client_id,
922 encrypted_client_secret,
923 token_endpoint_signing_alg,
924 token_endpoint_auth_method,
925 id_token_signed_response_alg,
926 fetch_userinfo,
927 userinfo_signed_response_alg,
928 created_at,
929 disabled_at,
930 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
931 jwks_uri_override,
932 authorization_endpoint_override,
933 token_endpoint_override,
934 userinfo_endpoint_override,
935 discovery_mode,
936 pkce_mode,
937 response_mode,
938 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
939 forward_login_hint
940 FROM upstream_oauth_providers
941 WHERE disabled_at IS NULL
942 ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
943 "#,
944 )
945 .traced()
946 .fetch_all(&mut *self.conn)
947 .await?;
948
949 let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
950 Ok(res?)
951 }
952}