Source code for sirepo.auth_db.user

"""auth database models for user roles.

:copyright: Copyright (c) 2022 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdc, pkdlog, pkdp
import sirepo.auth_db
import sirepo.auth_role
import sirepo.srtime
import sirepo.util
import sqlalchemy


class UserRegistration(sirepo.auth_db.UserDbBase):
    """Registration record for a user, stored in ``user_registration_t``.

    Rows are keyed by ``uid``. ``created`` is required; ``display_name``
    may be left unset.
    """

    __tablename__ = "user_registration_t"

    # Primary key: the user's id.
    uid = sqlalchemy.Column(
        sirepo.auth_db.STRING_ID,
        primary_key=True,
    )
    # Timestamp stored for the registration; NOT NULL.
    created = sqlalchemy.Column(
        sqlalchemy.DateTime(),
        nullable=False,
    )
    # Optional name for the user.
    display_name = sqlalchemy.Column(
        sirepo.auth_db.STRING_NAME,
    )
class UserRole(sirepo.auth_db.UserDbBase):
    """Roles held by users, stored in ``user_role_t``.

    Each row is keyed by ``(uid, role)`` and carries an optional
    ``expiration``. A row without an expiration is never reported as
    expired by the checks in this class.
    """

    __tablename__ = "user_role_t"

    # Composite primary key: one row per (uid, role) pair.
    uid = sqlalchemy.Column(
        sirepo.auth_db.STRING_ID,
        primary_key=True,
    )
    role = sqlalchemy.Column(
        sirepo.auth_db.STRING_NAME,
        primary_key=True,
    )
    # Nullable: None means the role has no expiration.
    expiration = sqlalchemy.Column(sqlalchemy.DateTime())

    def all_roles(self):
        """Return the first element of every row of a query made distinct on ``role``."""
        role_column = self.__class__.role
        rows = self.query().distinct(role_column).all()
        return [row[0] for row in rows]

    def add_plan(self, role, uid, expiration=None):
        """Give ``uid`` the plan ``role``, expiring a different active plan first.

        Args:
            role: plan role to apply
            uid: user to modify
            expiration: overrides the value computed by
                ``sirepo.util.plan_role_expiration`` when truthy
        """
        # TODO(robnagler) always trust stripe?
        # Assert role and probably need sanity check...
        # plan_role_expiration is always called, even when an explicit
        # expiration is supplied.
        computed = sirepo.util.plan_role_expiration(role)
        new_expiration = expiration or computed
        # check active plan, warn and expire them
        current = self.unchecked_active_plan(uid)
        if current and current.role != role:
            pkdlog(
                "user {} had existing active plan: [{}, {}], expiring now before applying new plan: [{}, {}]",
                uid,
                current.role,
                current.expiration,
                role,
                new_expiration,
            )
            self.expire_role(current.role, uid)
        self.add_roles([role], uid, expiration=new_expiration)

    def add_roles(self, roles, uid, expiration=None):
        """Add roles or update expiration

        Args:
            roles: iterable of role names; each must be longer than one character
            uid: user to modify
            expiration: stored on every added or updated row

        Raises:
            AssertionError: if a role name has length <= 1
        """
        from sirepo import sim_data

        for r in roles:
            if len(r) <= 1:
                raise AssertionError(f"no single letter role={r}")
            # Check here, because sqlite doesn't throw IntegrityErrors
            # at the point of the new() operation.
            if not self._has_role(r, uid):
                self.new(uid=uid, role=r, expiration=expiration).save()
                continue
            self.set_role_expiration(r, uid, expiration)
        # Run the proprietary lib file audit once, after all roles are written.
        sim_data.audit_proprietary_lib_files(qcall=self.auth_db.qcall, uid=uid)

    def delete_roles(self, roles, uid):
        """Delete the rows for ``uid`` whose role is in ``roles``, then run the audit."""
        from sirepo import sim_data

        model = self.__class__
        statement = sqlalchemy.delete(model)
        statement = statement.where(model.uid == uid)
        statement = statement.where(model.role.in_(roles))
        self.auth_db.execute(statement)
        sim_data.audit_proprietary_lib_files(qcall=self.auth_db.qcall, uid=uid)

    def expire_role(self, role, uid):
        """Set the expiration of ``(uid, role)`` to the current ``srtime`` UTC time."""
        now = sirepo.srtime.utc_now()
        self.set_role_expiration(role, uid, expiration=now)

    def get_roles(self, uid):
        """Return the ``role`` column values found for ``uid``."""
        return self.search_all_for_column("role", uid=uid)

    def get_roles_and_expiration(self, uid):
        """Return a list of ``PKDict(role, expiration)``, one per row of ``uid``."""
        rows = self.query().filter_by(uid=uid)
        return [
            PKDict(role=row.role, expiration=row.expiration) for row in rows
        ]

    def has_active_plan(self, uid):
        """Return True if `unchecked_active_plan` finds a row for ``uid``."""
        plan = self.unchecked_active_plan(uid)
        return bool(plan)

    def has_active_role(self, role, uid):
        """Return a truthy value when ``uid`` holds ``role`` and it is not expired.

        When no row is found, the falsy lookup result itself is returned.
        """
        record = self._has_role(role, uid)
        if not record:
            return record
        return not self._is_expired_role(record)

    def has_expired_role(self, role, uid):
        """Return a truthy value when ``uid`` holds ``role`` and it is expired.

        When no row is found, the falsy lookup result itself is returned.
        """
        record = self._has_role(role, uid)
        if not record:
            return record
        return self._is_expired_role(record)

    def set_role_expiration(self, role, uid, expiration):
        """Store ``expiration`` on the existing ``(uid, role)`` row and save it."""
        # search_by presumably fails when the row is missing; confirm in UserDbBase
        record = self.search_by(uid=uid, role=role)
        record.expiration = expiration
        record.save()

    def uids_of_paid_users(self):
        """Return uids holding any role in ``sirepo.auth_role.PLAN_ROLES_PAID``."""
        return self.uids_with_roles(sirepo.auth_role.PLAN_ROLES_PAID)

    def uids_with_roles(self, roles):
        """Return the distinct uids that hold at least one of ``roles``.

        Every role is first passed to ``sirepo.auth_role.check``.
        """
        for name in roles:
            sirepo.auth_role.check(name)
        model = self.__class__
        uid_query = (
            self.query()
            .with_entities(model.uid)
            .filter(model.role.in_(roles))
            .distinct()
        )
        return [row[0] for row in uid_query.all()]

    def unchecked_active_plan(self, uid):
        """Return the first unexpired plan-role row for ``uid``.

        A row qualifies when its role is in ``sirepo.auth_role.PLAN_ROLES``
        and its expiration is NULL or later than the current time. The
        result is whatever ``first()`` yields for that query.
        """
        model = self.__class__
        base = self.query()
        not_expired = sqlalchemy.or_(
            model.expiration.is_(None),
            model.expiration > sirepo.srtime.utc_now(),
        )
        return base.filter(
            model.role.in_(sirepo.auth_role.PLAN_ROLES),
            model.uid == uid,
            not_expired,
        ).first()

    def _has_role(self, role, uid):
        """Look up the ``(uid, role)`` row via ``unchecked_search_by``."""
        return self.unchecked_search_by(uid=uid, role=role)

    def _is_expired_role(self, role_record):
        """Return a truthy value when ``role_record`` has an expiration in the past.

        A falsy expiration is returned unchanged, so such a record is never
        considered expired.
        """
        if not role_record.expiration:
            return role_record.expiration
        return role_record.expiration < sirepo.srtime.utc_now()
class UserRoleModeration(sirepo.auth_db.UserDbBase):
    """Moderation state of role requests, stored in ``user_role_moderation_t``.

    Each row is keyed by ``(uid, role)`` and records a required ``status``,
    an optional ``moderator_uid``, and a ``last_updated`` timestamp.
    """

    __tablename__ = "user_role_moderation_t"

    # Composite primary key: one row per (uid, role) pair.
    uid = sqlalchemy.Column(
        sirepo.auth_db.STRING_ID,
        primary_key=True,
    )
    role = sqlalchemy.Column(
        sirepo.auth_db.STRING_NAME,
        primary_key=True,
    )
    # Required; set_status validates it with ModerationStatus.check.
    status = sqlalchemy.Column(
        sirepo.auth_db.STRING_NAME,
        nullable=False,
    )
    # Optional id of the moderator.
    moderator_uid = sqlalchemy.Column(sirepo.auth_db.STRING_ID)
    # SQL now() is configured as both the server default and the onupdate value.
    last_updated = sqlalchemy.Column(
        sqlalchemy.DateTime(),
        server_default=sqlalchemy.sql.func.now(),
        onupdate=sqlalchemy.sql.func.now(),
        nullable=False,
    )

    def get_moderation_request_rows(self):
        """Return rows whose status is "pending" or "clarify".

        Each result is a ``PKDict`` holding every column of this table plus
        ``email``, taken from ``user_name`` of the ``AuthEmailUser`` row
        with the same uid.
        """
        moderation = self.__class__
        email_user = self.auth_db.model("AuthEmailUser").__class__
        open_request = sqlalchemy.sql.expression.or_(
            moderation.status == "pending",
            moderation.status == "clarify",
        )
        rows = (
            self.auth_db.query(email_user)
            .with_entities(
                email_user.user_name.label("email"),
                *moderation.__table__.columns,
            )
            .filter(email_user.uid == moderation.uid, open_request)
            .all()
        )
        return [PKDict(zip(row.keys(), row)) for row in rows]

    def get_status(self, role, uid):
        """Return the checked moderation status of ``(uid, role)``, or None if no row is found."""
        record = self.unchecked_search_by(uid=uid, role=role)
        return (
            sirepo.auth_role.ModerationStatus.check(record.status)
            if record
            else None
        )

    def set_status(self, role, uid, status, moderator_uid):
        """Store a checked ``status`` on the existing ``(uid, role)`` row.

        ``moderator_uid`` is written only when truthy; otherwise the stored
        value is left as is.
        """
        record = self.search_by(uid=uid, role=role)
        record.status = sirepo.auth_role.ModerationStatus.check(status)
        if moderator_uid:
            record.moderator_uid = moderator_uid
        record.save()