mas_storage_pg/upstream_oauth2/
provider.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11    Clock, Page, Pagination,
12    upstream_oauth2::{
13        UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
14    },
15};
16use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::{PgConnection, types::Json};
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26    DatabaseError, DatabaseInconsistencyError,
27    filter::{Filter, StatementExt},
28    iden::UpstreamOAuthProviders,
29    pagination::QueryBuilderExt,
30    tracing::ExecuteExt,
31};
32
33/// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL
34/// connection
35pub struct PgUpstreamOAuthProviderRepository<'c> {
36    conn: &'c mut PgConnection,
37}
38
39impl<'c> PgUpstreamOAuthProviderRepository<'c> {
40    /// Create a new [`PgUpstreamOAuthProviderRepository`] from an active
41    /// PostgreSQL connection
42    pub fn new(conn: &'c mut PgConnection) -> Self {
43        Self { conn }
44    }
45}
46
/// One row of the `upstream_oauth_providers` table, as read by `lookup` and
/// `list`.
///
/// The field names double as column names. `lookup` selects columns with
/// exactly these names. `list` aliases each selected column with the matching
/// `ProviderLookupIden` variant generated by `#[enum_def]`. A field therefore
/// cannot be renamed without updating both queries.
///
/// Most text columns are kept raw here and parsed into typed values by the
/// `TryFrom<ProviderLookup> for UpstreamOAuthProvider` conversion.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Raw text; parsed into a scope by the `TryFrom` conversion.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // The `Filter` impl treats `disabled_at IS NULL` as "enabled".
    disabled_at: Option<DateTime<Utc>>,
    // JSON column decoded directly into the claims-import configuration.
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // JSON list of (name, value) pairs; the conversion turns a NULL column
    // into an empty list.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
}
75
76impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
77    type Error = DatabaseInconsistencyError;
78
79    #[allow(clippy::too_many_lines)]
80    fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
81        let id = value.upstream_oauth_provider_id.into();
82        let scope = value.scope.parse().map_err(|e| {
83            DatabaseInconsistencyError::on("upstream_oauth_providers")
84                .column("scope")
85                .row(id)
86                .source(e)
87        })?;
88        let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
89            DatabaseInconsistencyError::on("upstream_oauth_providers")
90                .column("token_endpoint_auth_method")
91                .row(id)
92                .source(e)
93        })?;
94        let token_endpoint_signing_alg = value
95            .token_endpoint_signing_alg
96            .map(|x| x.parse())
97            .transpose()
98            .map_err(|e| {
99                DatabaseInconsistencyError::on("upstream_oauth_providers")
100                    .column("token_endpoint_signing_alg")
101                    .row(id)
102                    .source(e)
103            })?;
104        let id_token_signed_response_alg =
105            value.id_token_signed_response_alg.parse().map_err(|e| {
106                DatabaseInconsistencyError::on("upstream_oauth_providers")
107                    .column("id_token_signed_response_alg")
108                    .row(id)
109                    .source(e)
110            })?;
111
112        let userinfo_signed_response_alg = value
113            .userinfo_signed_response_alg
114            .map(|x| x.parse())
115            .transpose()
116            .map_err(|e| {
117                DatabaseInconsistencyError::on("upstream_oauth_providers")
118                    .column("userinfo_signed_response_alg")
119                    .row(id)
120                    .source(e)
121            })?;
122
123        let authorization_endpoint_override = value
124            .authorization_endpoint_override
125            .map(|x| x.parse())
126            .transpose()
127            .map_err(|e| {
128                DatabaseInconsistencyError::on("upstream_oauth_providers")
129                    .column("authorization_endpoint_override")
130                    .row(id)
131                    .source(e)
132            })?;
133
134        let token_endpoint_override = value
135            .token_endpoint_override
136            .map(|x| x.parse())
137            .transpose()
138            .map_err(|e| {
139                DatabaseInconsistencyError::on("upstream_oauth_providers")
140                    .column("token_endpoint_override")
141                    .row(id)
142                    .source(e)
143            })?;
144
145        let userinfo_endpoint_override = value
146            .userinfo_endpoint_override
147            .map(|x| x.parse())
148            .transpose()
149            .map_err(|e| {
150                DatabaseInconsistencyError::on("upstream_oauth_providers")
151                    .column("userinfo_endpoint_override")
152                    .row(id)
153                    .source(e)
154            })?;
155
156        let jwks_uri_override = value
157            .jwks_uri_override
158            .map(|x| x.parse())
159            .transpose()
160            .map_err(|e| {
161                DatabaseInconsistencyError::on("upstream_oauth_providers")
162                    .column("jwks_uri_override")
163                    .row(id)
164                    .source(e)
165            })?;
166
167        let discovery_mode = value.discovery_mode.parse().map_err(|e| {
168            DatabaseInconsistencyError::on("upstream_oauth_providers")
169                .column("discovery_mode")
170                .row(id)
171                .source(e)
172        })?;
173
174        let pkce_mode = value.pkce_mode.parse().map_err(|e| {
175            DatabaseInconsistencyError::on("upstream_oauth_providers")
176                .column("pkce_mode")
177                .row(id)
178                .source(e)
179        })?;
180
181        let response_mode = value
182            .response_mode
183            .map(|x| x.parse())
184            .transpose()
185            .map_err(|e| {
186                DatabaseInconsistencyError::on("upstream_oauth_providers")
187                    .column("response_mode")
188                    .row(id)
189                    .source(e)
190            })?;
191
192        let additional_authorization_parameters = value
193            .additional_parameters
194            .map(|Json(x)| x)
195            .unwrap_or_default();
196
197        Ok(UpstreamOAuthProvider {
198            id,
199            issuer: value.issuer,
200            human_name: value.human_name,
201            brand_name: value.brand_name,
202            scope,
203            client_id: value.client_id,
204            encrypted_client_secret: value.encrypted_client_secret,
205            token_endpoint_auth_method,
206            token_endpoint_signing_alg,
207            id_token_signed_response_alg,
208            fetch_userinfo: value.fetch_userinfo,
209            userinfo_signed_response_alg,
210            created_at: value.created_at,
211            disabled_at: value.disabled_at,
212            claims_imports: value.claims_imports.0,
213            authorization_endpoint_override,
214            token_endpoint_override,
215            userinfo_endpoint_override,
216            jwks_uri_override,
217            discovery_mode,
218            pkce_mode,
219            response_mode,
220            additional_authorization_parameters,
221            forward_login_hint: value.forward_login_hint,
222        })
223    }
224}
225
226impl Filter for UpstreamOAuthProviderFilter<'_> {
227    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
228        sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
229            Expr::col((
230                UpstreamOAuthProviders::Table,
231                UpstreamOAuthProviders::DisabledAt,
232            ))
233            .is_null()
234            .eq(enabled)
235        }))
236    }
237}
238
239#[async_trait]
240impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
241    type Error = DatabaseError;
242
243    #[tracing::instrument(
244        name = "db.upstream_oauth_provider.lookup",
245        skip_all,
246        fields(
247            db.query.text,
248            upstream_oauth_provider.id = %id,
249        ),
250        err,
251    )]
252    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
253        let res = sqlx::query_as!(
254            ProviderLookup,
255            r#"
256                SELECT
257                    upstream_oauth_provider_id,
258                    issuer,
259                    human_name,
260                    brand_name,
261                    scope,
262                    client_id,
263                    encrypted_client_secret,
264                    token_endpoint_signing_alg,
265                    token_endpoint_auth_method,
266                    id_token_signed_response_alg,
267                    fetch_userinfo,
268                    userinfo_signed_response_alg,
269                    created_at,
270                    disabled_at,
271                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
272                    jwks_uri_override,
273                    authorization_endpoint_override,
274                    token_endpoint_override,
275                    userinfo_endpoint_override,
276                    discovery_mode,
277                    pkce_mode,
278                    response_mode,
279                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
280                    forward_login_hint
281                FROM upstream_oauth_providers
282                WHERE upstream_oauth_provider_id = $1
283            "#,
284            Uuid::from(id),
285        )
286        .traced()
287        .fetch_optional(&mut *self.conn)
288        .await?;
289
290        let res = res
291            .map(UpstreamOAuthProvider::try_from)
292            .transpose()
293            .map_err(DatabaseError::from)?;
294
295        Ok(res)
296    }
297
298    #[tracing::instrument(
299        name = "db.upstream_oauth_provider.add",
300        skip_all,
301        fields(
302            db.query.text,
303            upstream_oauth_provider.id,
304            upstream_oauth_provider.issuer = params.issuer,
305            upstream_oauth_provider.client_id = %params.client_id,
306        ),
307        err,
308    )]
309    async fn add(
310        &mut self,
311        rng: &mut (dyn RngCore + Send),
312        clock: &dyn Clock,
313        params: UpstreamOAuthProviderParams,
314    ) -> Result<UpstreamOAuthProvider, Self::Error> {
315        let created_at = clock.now();
316        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
317        tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
318
319        sqlx::query!(
320            r#"
321            INSERT INTO upstream_oauth_providers (
322                upstream_oauth_provider_id,
323                issuer,
324                human_name,
325                brand_name,
326                scope,
327                token_endpoint_auth_method,
328                token_endpoint_signing_alg,
329                id_token_signed_response_alg,
330                fetch_userinfo,
331                userinfo_signed_response_alg,
332                client_id,
333                encrypted_client_secret,
334                claims_imports,
335                authorization_endpoint_override,
336                token_endpoint_override,
337                userinfo_endpoint_override,
338                jwks_uri_override,
339                discovery_mode,
340                pkce_mode,
341                response_mode,
342                forward_login_hint,
343                created_at
344            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
345                      $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
346        "#,
347            Uuid::from(id),
348            params.issuer.as_deref(),
349            params.human_name.as_deref(),
350            params.brand_name.as_deref(),
351            params.scope.to_string(),
352            params.token_endpoint_auth_method.to_string(),
353            params
354                .token_endpoint_signing_alg
355                .as_ref()
356                .map(ToString::to_string),
357            params.id_token_signed_response_alg.to_string(),
358            params.fetch_userinfo,
359            params
360                .userinfo_signed_response_alg
361                .as_ref()
362                .map(ToString::to_string),
363            &params.client_id,
364            params.encrypted_client_secret.as_deref(),
365            Json(&params.claims_imports) as _,
366            params
367                .authorization_endpoint_override
368                .as_ref()
369                .map(ToString::to_string),
370            params
371                .token_endpoint_override
372                .as_ref()
373                .map(ToString::to_string),
374            params
375                .userinfo_endpoint_override
376                .as_ref()
377                .map(ToString::to_string),
378            params.jwks_uri_override.as_ref().map(ToString::to_string),
379            params.discovery_mode.as_str(),
380            params.pkce_mode.as_str(),
381            params.response_mode.as_ref().map(ToString::to_string),
382            params.forward_login_hint,
383            created_at,
384        )
385        .traced()
386        .execute(&mut *self.conn)
387        .await?;
388
389        Ok(UpstreamOAuthProvider {
390            id,
391            issuer: params.issuer,
392            human_name: params.human_name,
393            brand_name: params.brand_name,
394            scope: params.scope,
395            client_id: params.client_id,
396            encrypted_client_secret: params.encrypted_client_secret,
397            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
398            token_endpoint_auth_method: params.token_endpoint_auth_method,
399            id_token_signed_response_alg: params.id_token_signed_response_alg,
400            fetch_userinfo: params.fetch_userinfo,
401            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
402            created_at,
403            disabled_at: None,
404            claims_imports: params.claims_imports,
405            authorization_endpoint_override: params.authorization_endpoint_override,
406            token_endpoint_override: params.token_endpoint_override,
407            userinfo_endpoint_override: params.userinfo_endpoint_override,
408            jwks_uri_override: params.jwks_uri_override,
409            discovery_mode: params.discovery_mode,
410            pkce_mode: params.pkce_mode,
411            response_mode: params.response_mode,
412            additional_authorization_parameters: params.additional_authorization_parameters,
413            forward_login_hint: params.forward_login_hint,
414        })
415    }
416
417    #[tracing::instrument(
418        name = "db.upstream_oauth_provider.delete_by_id",
419        skip_all,
420        fields(
421            db.query.text,
422            upstream_oauth_provider.id = %id,
423        ),
424        err,
425    )]
426    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
427        // Delete the authorization sessions first, as they have a foreign key
428        // constraint on the links and the providers.
429        {
430            let span = info_span!(
431                "db.oauth2_client.delete_by_id.authorization_sessions",
432                upstream_oauth_provider.id = %id,
433                { DB_QUERY_TEXT } = tracing::field::Empty,
434            );
435            sqlx::query!(
436                r#"
437                    DELETE FROM upstream_oauth_authorization_sessions
438                    WHERE upstream_oauth_provider_id = $1
439                "#,
440                Uuid::from(id),
441            )
442            .record(&span)
443            .execute(&mut *self.conn)
444            .instrument(span)
445            .await?;
446        }
447
448        // Delete the links next, as they have a foreign key constraint on the
449        // providers.
450        {
451            let span = info_span!(
452                "db.oauth2_client.delete_by_id.links",
453                upstream_oauth_provider.id = %id,
454                { DB_QUERY_TEXT } = tracing::field::Empty,
455            );
456            sqlx::query!(
457                r#"
458                    DELETE FROM upstream_oauth_links
459                    WHERE upstream_oauth_provider_id = $1
460                "#,
461                Uuid::from(id),
462            )
463            .record(&span)
464            .execute(&mut *self.conn)
465            .instrument(span)
466            .await?;
467        }
468
469        let res = sqlx::query!(
470            r#"
471                DELETE FROM upstream_oauth_providers
472                WHERE upstream_oauth_provider_id = $1
473            "#,
474            Uuid::from(id),
475        )
476        .traced()
477        .execute(&mut *self.conn)
478        .await?;
479
480        DatabaseError::ensure_affected_rows(&res, 1)
481    }
482
483    #[tracing::instrument(
484        name = "db.upstream_oauth_provider.add",
485        skip_all,
486        fields(
487            db.query.text,
488            upstream_oauth_provider.id = %id,
489            upstream_oauth_provider.issuer = params.issuer,
490            upstream_oauth_provider.client_id = %params.client_id,
491        ),
492        err,
493    )]
494    async fn upsert(
495        &mut self,
496        clock: &dyn Clock,
497        id: Ulid,
498        params: UpstreamOAuthProviderParams,
499    ) -> Result<UpstreamOAuthProvider, Self::Error> {
500        let created_at = clock.now();
501
502        let created_at = sqlx::query_scalar!(
503            r#"
504                INSERT INTO upstream_oauth_providers (
505                    upstream_oauth_provider_id,
506                    issuer,
507                    human_name,
508                    brand_name,
509                    scope,
510                    token_endpoint_auth_method,
511                    token_endpoint_signing_alg,
512                    id_token_signed_response_alg,
513                    fetch_userinfo,
514                    userinfo_signed_response_alg,
515                    client_id,
516                    encrypted_client_secret,
517                    claims_imports,
518                    authorization_endpoint_override,
519                    token_endpoint_override,
520                    userinfo_endpoint_override,
521                    jwks_uri_override,
522                    discovery_mode,
523                    pkce_mode,
524                    response_mode,
525                    additional_parameters,
526                    forward_login_hint,
527                    ui_order,
528                    created_at
529                ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
530                          $12, $13, $14, $15, $16, $17, $18, $19, $20,
531                          $21, $22, $23, $24)
532                ON CONFLICT (upstream_oauth_provider_id)
533                    DO UPDATE
534                    SET
535                        issuer = EXCLUDED.issuer,
536                        human_name = EXCLUDED.human_name,
537                        brand_name = EXCLUDED.brand_name,
538                        scope = EXCLUDED.scope,
539                        token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
540                        token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
541                        id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
542                        fetch_userinfo = EXCLUDED.fetch_userinfo,
543                        userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
544                        disabled_at = NULL,
545                        client_id = EXCLUDED.client_id,
546                        encrypted_client_secret = EXCLUDED.encrypted_client_secret,
547                        claims_imports = EXCLUDED.claims_imports,
548                        authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
549                        token_endpoint_override = EXCLUDED.token_endpoint_override,
550                        userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
551                        jwks_uri_override = EXCLUDED.jwks_uri_override,
552                        discovery_mode = EXCLUDED.discovery_mode,
553                        pkce_mode = EXCLUDED.pkce_mode,
554                        response_mode = EXCLUDED.response_mode,
555                        additional_parameters = EXCLUDED.additional_parameters,
556                        forward_login_hint = EXCLUDED.forward_login_hint,
557                        ui_order = EXCLUDED.ui_order
558                RETURNING created_at
559            "#,
560            Uuid::from(id),
561            params.issuer.as_deref(),
562            params.human_name.as_deref(),
563            params.brand_name.as_deref(),
564            params.scope.to_string(),
565            params.token_endpoint_auth_method.to_string(),
566            params
567                .token_endpoint_signing_alg
568                .as_ref()
569                .map(ToString::to_string),
570            params.id_token_signed_response_alg.to_string(),
571            params.fetch_userinfo,
572            params
573                .userinfo_signed_response_alg
574                .as_ref()
575                .map(ToString::to_string),
576            &params.client_id,
577            params.encrypted_client_secret.as_deref(),
578            Json(&params.claims_imports) as _,
579            params
580                .authorization_endpoint_override
581                .as_ref()
582                .map(ToString::to_string),
583            params
584                .token_endpoint_override
585                .as_ref()
586                .map(ToString::to_string),
587            params
588                .userinfo_endpoint_override
589                .as_ref()
590                .map(ToString::to_string),
591            params.jwks_uri_override.as_ref().map(ToString::to_string),
592            params.discovery_mode.as_str(),
593            params.pkce_mode.as_str(),
594            params.response_mode.as_ref().map(ToString::to_string),
595            Json(&params.additional_authorization_parameters) as _,
596            params.forward_login_hint,
597            params.ui_order,
598            created_at,
599        )
600        .traced()
601        .fetch_one(&mut *self.conn)
602        .await?;
603
604        Ok(UpstreamOAuthProvider {
605            id,
606            issuer: params.issuer,
607            human_name: params.human_name,
608            brand_name: params.brand_name,
609            scope: params.scope,
610            client_id: params.client_id,
611            encrypted_client_secret: params.encrypted_client_secret,
612            token_endpoint_signing_alg: params.token_endpoint_signing_alg,
613            token_endpoint_auth_method: params.token_endpoint_auth_method,
614            id_token_signed_response_alg: params.id_token_signed_response_alg,
615            fetch_userinfo: params.fetch_userinfo,
616            userinfo_signed_response_alg: params.userinfo_signed_response_alg,
617            created_at,
618            disabled_at: None,
619            claims_imports: params.claims_imports,
620            authorization_endpoint_override: params.authorization_endpoint_override,
621            token_endpoint_override: params.token_endpoint_override,
622            userinfo_endpoint_override: params.userinfo_endpoint_override,
623            jwks_uri_override: params.jwks_uri_override,
624            discovery_mode: params.discovery_mode,
625            pkce_mode: params.pkce_mode,
626            response_mode: params.response_mode,
627            additional_authorization_parameters: params.additional_authorization_parameters,
628            forward_login_hint: params.forward_login_hint,
629        })
630    }
631
632    #[tracing::instrument(
633        name = "db.upstream_oauth_provider.disable",
634        skip_all,
635        fields(
636            db.query.text,
637            %upstream_oauth_provider.id,
638        ),
639        err,
640    )]
641    async fn disable(
642        &mut self,
643        clock: &dyn Clock,
644        mut upstream_oauth_provider: UpstreamOAuthProvider,
645    ) -> Result<UpstreamOAuthProvider, Self::Error> {
646        let disabled_at = clock.now();
647        let res = sqlx::query!(
648            r#"
649                UPDATE upstream_oauth_providers
650                SET disabled_at = $2
651                WHERE upstream_oauth_provider_id = $1
652            "#,
653            Uuid::from(upstream_oauth_provider.id),
654            disabled_at,
655        )
656        .traced()
657        .execute(&mut *self.conn)
658        .await?;
659
660        DatabaseError::ensure_affected_rows(&res, 1)?;
661
662        upstream_oauth_provider.disabled_at = Some(disabled_at);
663
664        Ok(upstream_oauth_provider)
665    }
666
667    #[tracing::instrument(
668        name = "db.upstream_oauth_provider.list",
669        skip_all,
670        fields(
671            db.query.text,
672        ),
673        err,
674    )]
675    async fn list(
676        &mut self,
677        filter: UpstreamOAuthProviderFilter<'_>,
678        pagination: Pagination,
679    ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
680        let (sql, arguments) = Query::select()
681            .expr_as(
682                Expr::col((
683                    UpstreamOAuthProviders::Table,
684                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
685                )),
686                ProviderLookupIden::UpstreamOauthProviderId,
687            )
688            .expr_as(
689                Expr::col((
690                    UpstreamOAuthProviders::Table,
691                    UpstreamOAuthProviders::Issuer,
692                )),
693                ProviderLookupIden::Issuer,
694            )
695            .expr_as(
696                Expr::col((
697                    UpstreamOAuthProviders::Table,
698                    UpstreamOAuthProviders::HumanName,
699                )),
700                ProviderLookupIden::HumanName,
701            )
702            .expr_as(
703                Expr::col((
704                    UpstreamOAuthProviders::Table,
705                    UpstreamOAuthProviders::BrandName,
706                )),
707                ProviderLookupIden::BrandName,
708            )
709            .expr_as(
710                Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
711                ProviderLookupIden::Scope,
712            )
713            .expr_as(
714                Expr::col((
715                    UpstreamOAuthProviders::Table,
716                    UpstreamOAuthProviders::ClientId,
717                )),
718                ProviderLookupIden::ClientId,
719            )
720            .expr_as(
721                Expr::col((
722                    UpstreamOAuthProviders::Table,
723                    UpstreamOAuthProviders::EncryptedClientSecret,
724                )),
725                ProviderLookupIden::EncryptedClientSecret,
726            )
727            .expr_as(
728                Expr::col((
729                    UpstreamOAuthProviders::Table,
730                    UpstreamOAuthProviders::TokenEndpointSigningAlg,
731                )),
732                ProviderLookupIden::TokenEndpointSigningAlg,
733            )
734            .expr_as(
735                Expr::col((
736                    UpstreamOAuthProviders::Table,
737                    UpstreamOAuthProviders::TokenEndpointAuthMethod,
738                )),
739                ProviderLookupIden::TokenEndpointAuthMethod,
740            )
741            .expr_as(
742                Expr::col((
743                    UpstreamOAuthProviders::Table,
744                    UpstreamOAuthProviders::IdTokenSignedResponseAlg,
745                )),
746                ProviderLookupIden::IdTokenSignedResponseAlg,
747            )
748            .expr_as(
749                Expr::col((
750                    UpstreamOAuthProviders::Table,
751                    UpstreamOAuthProviders::FetchUserinfo,
752                )),
753                ProviderLookupIden::FetchUserinfo,
754            )
755            .expr_as(
756                Expr::col((
757                    UpstreamOAuthProviders::Table,
758                    UpstreamOAuthProviders::UserinfoSignedResponseAlg,
759                )),
760                ProviderLookupIden::UserinfoSignedResponseAlg,
761            )
762            .expr_as(
763                Expr::col((
764                    UpstreamOAuthProviders::Table,
765                    UpstreamOAuthProviders::CreatedAt,
766                )),
767                ProviderLookupIden::CreatedAt,
768            )
769            .expr_as(
770                Expr::col((
771                    UpstreamOAuthProviders::Table,
772                    UpstreamOAuthProviders::DisabledAt,
773                )),
774                ProviderLookupIden::DisabledAt,
775            )
776            .expr_as(
777                Expr::col((
778                    UpstreamOAuthProviders::Table,
779                    UpstreamOAuthProviders::ClaimsImports,
780                )),
781                ProviderLookupIden::ClaimsImports,
782            )
783            .expr_as(
784                Expr::col((
785                    UpstreamOAuthProviders::Table,
786                    UpstreamOAuthProviders::JwksUriOverride,
787                )),
788                ProviderLookupIden::JwksUriOverride,
789            )
790            .expr_as(
791                Expr::col((
792                    UpstreamOAuthProviders::Table,
793                    UpstreamOAuthProviders::TokenEndpointOverride,
794                )),
795                ProviderLookupIden::TokenEndpointOverride,
796            )
797            .expr_as(
798                Expr::col((
799                    UpstreamOAuthProviders::Table,
800                    UpstreamOAuthProviders::AuthorizationEndpointOverride,
801                )),
802                ProviderLookupIden::AuthorizationEndpointOverride,
803            )
804            .expr_as(
805                Expr::col((
806                    UpstreamOAuthProviders::Table,
807                    UpstreamOAuthProviders::UserinfoEndpointOverride,
808                )),
809                ProviderLookupIden::UserinfoEndpointOverride,
810            )
811            .expr_as(
812                Expr::col((
813                    UpstreamOAuthProviders::Table,
814                    UpstreamOAuthProviders::DiscoveryMode,
815                )),
816                ProviderLookupIden::DiscoveryMode,
817            )
818            .expr_as(
819                Expr::col((
820                    UpstreamOAuthProviders::Table,
821                    UpstreamOAuthProviders::PkceMode,
822                )),
823                ProviderLookupIden::PkceMode,
824            )
825            .expr_as(
826                Expr::col((
827                    UpstreamOAuthProviders::Table,
828                    UpstreamOAuthProviders::ResponseMode,
829                )),
830                ProviderLookupIden::ResponseMode,
831            )
832            .expr_as(
833                Expr::col((
834                    UpstreamOAuthProviders::Table,
835                    UpstreamOAuthProviders::AdditionalParameters,
836                )),
837                ProviderLookupIden::AdditionalParameters,
838            )
839            .expr_as(
840                Expr::col((
841                    UpstreamOAuthProviders::Table,
842                    UpstreamOAuthProviders::ForwardLoginHint,
843                )),
844                ProviderLookupIden::ForwardLoginHint,
845            )
846            .from(UpstreamOAuthProviders::Table)
847            .apply_filter(filter)
848            .generate_pagination(
849                (
850                    UpstreamOAuthProviders::Table,
851                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
852                ),
853                pagination,
854            )
855            .build_sqlx(PostgresQueryBuilder);
856
857        let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
858            .traced()
859            .fetch_all(&mut *self.conn)
860            .await?;
861
862        let page = pagination
863            .process(edges)
864            .try_map(UpstreamOAuthProvider::try_from)?;
865
866        return Ok(page);
867    }
868
869    #[tracing::instrument(
870        name = "db.upstream_oauth_provider.count",
871        skip_all,
872        fields(
873            db.query.text,
874        ),
875        err,
876    )]
877    async fn count(
878        &mut self,
879        filter: UpstreamOAuthProviderFilter<'_>,
880    ) -> Result<usize, Self::Error> {
881        let (sql, arguments) = Query::select()
882            .expr(
883                Expr::col((
884                    UpstreamOAuthProviders::Table,
885                    UpstreamOAuthProviders::UpstreamOAuthProviderId,
886                ))
887                .count(),
888            )
889            .from(UpstreamOAuthProviders::Table)
890            .apply_filter(filter)
891            .build_sqlx(PostgresQueryBuilder);
892
893        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
894            .traced()
895            .fetch_one(&mut *self.conn)
896            .await?;
897
898        count
899            .try_into()
900            .map_err(DatabaseError::to_invalid_operation)
901    }
902
903    #[tracing::instrument(
904        name = "db.upstream_oauth_provider.all_enabled",
905        skip_all,
906        fields(
907            db.query.text,
908        ),
909        err,
910    )]
911    async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
912        let res = sqlx::query_as!(
913            ProviderLookup,
914            r#"
915                SELECT
916                    upstream_oauth_provider_id,
917                    issuer,
918                    human_name,
919                    brand_name,
920                    scope,
921                    client_id,
922                    encrypted_client_secret,
923                    token_endpoint_signing_alg,
924                    token_endpoint_auth_method,
925                    id_token_signed_response_alg,
926                    fetch_userinfo,
927                    userinfo_signed_response_alg,
928                    created_at,
929                    disabled_at,
930                    claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
931                    jwks_uri_override,
932                    authorization_endpoint_override,
933                    token_endpoint_override,
934                    userinfo_endpoint_override,
935                    discovery_mode,
936                    pkce_mode,
937                    response_mode,
938                    additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
939                    forward_login_hint
940                FROM upstream_oauth_providers
941                WHERE disabled_at IS NULL
942                ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
943            "#,
944        )
945        .traced()
946        .fetch_all(&mut *self.conn)
947        .await?;
948
949        let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
950        Ok(res?)
951    }
952}