committed by
GitHub
13 changed files with 104 additions and 307 deletions
@ -0,0 +1,55 @@ |
|||||
|
from typing import Any |
||||
|
|
||||
|
from sqlmodel import Session, select |
||||
|
|
||||
|
from app.core.security import get_password_hash, verify_password |
||||
|
from app.models import Item, ItemCreate, User, UserCreate, UserUpdate |
||||
|
|
||||
|
|
||||
|
def create_user(*, session: Session, user_create: UserCreate) -> User: |
||||
|
db_obj = User.model_validate( |
||||
|
user_create, update={"hashed_password": get_password_hash(user_create.password)} |
||||
|
) |
||||
|
session.add(db_obj) |
||||
|
session.commit() |
||||
|
session.refresh(db_obj) |
||||
|
return db_obj |
||||
|
|
||||
|
|
||||
|
def update_user(*, session: Session, user_id: int, user_in: UserUpdate) -> Any: |
||||
|
db_user = session.get(User, user_id) |
||||
|
if not db_user: |
||||
|
return None |
||||
|
user_data = user_in.model_dump(exclude_unset=True) |
||||
|
extra_data = {} |
||||
|
if "password" in user_data: |
||||
|
password = user_data["password"] |
||||
|
hashed_password = get_password_hash(password) |
||||
|
extra_data["hashed_password"] = hashed_password |
||||
|
db_user.sqlmodel_update(user_data, update=extra_data) |
||||
|
session.add(db_user) |
||||
|
session.commit() |
||||
|
session.refresh(db_user) |
||||
|
return db_user |
||||
|
|
||||
|
|
||||
|
def get_user_by_email(*, session: Session, email: str) -> User | None: |
||||
|
statement = select(User).where(User.email == email) |
||||
|
session_user = session.exec(statement).first() |
||||
|
return session_user |
||||
|
|
||||
|
|
||||
|
def authenticate(*, session: Session, email: str, password: str) -> User | None: |
||||
|
db_user = get_user_by_email(session=session, email=email) |
||||
|
if not db_user: |
||||
|
return None |
||||
|
if not verify_password(password, db_user.hashed_password): |
||||
|
return None |
||||
|
return db_user |
||||
|
|
||||
|
def create_item(*, session: Session, item_in: ItemCreate, owner_id: int) -> Item: |
||||
|
db_item = Item.model_validate(item_in, update={"owner_id": owner_id}) |
||||
|
session.add(db_item) |
||||
|
session.commit() |
||||
|
session.refresh(db_item) |
||||
|
return db_item |
@ -1,37 +0,0 @@ |
|||||
# For a new basic set of CRUD operations you could just do |
|
||||
# from .base import CRUDBase |
|
||||
# from app.models.item import Item |
|
||||
# from app.schemas.item import ItemCreate, ItemUpdate |
|
||||
# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item) |
|
||||
from sqlmodel import Session, select |
|
||||
|
|
||||
from app.core.security import get_password_hash, verify_password |
|
||||
from app.models import User, UserCreate |
|
||||
|
|
||||
from .crud_item import item as item |
|
||||
from .crud_user import user as user |
|
||||
|
|
||||
|
|
||||
def create_user(*, session: Session, user_create: UserCreate) -> User: |
|
||||
db_obj = User.from_orm( |
|
||||
user_create, update={"hashed_password": get_password_hash(user_create.password)} |
|
||||
) |
|
||||
session.add(db_obj) |
|
||||
session.commit() |
|
||||
session.refresh(db_obj) |
|
||||
return db_obj |
|
||||
|
|
||||
|
|
||||
def get_user_by_email(*, session: Session, email: str) -> User | None: |
|
||||
statement = select(User).where(User.email == email) |
|
||||
session_user = session.exec(statement).first() |
|
||||
return session_user |
|
||||
|
|
||||
|
|
||||
def authenticate(*, session: Session, email: str, password: str) -> User | None: |
|
||||
db_user = get_user_by_email(session=session, email=email) |
|
||||
if not db_user: |
|
||||
return None |
|
||||
if not verify_password(password, db_user.hashed_password): |
|
||||
return None |
|
||||
return db_user |
|
@ -1,59 +0,0 @@ |
|||||
from typing import Any, Generic, TypeVar |
|
||||
|
|
||||
from fastapi.encoders import jsonable_encoder |
|
||||
from pydantic import BaseModel |
|
||||
from sqlalchemy.orm import Session |
|
||||
|
|
||||
ModelType = TypeVar("ModelType", bound=Any) |
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) |
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) |
|
||||
|
|
||||
|
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): |
|
||||
def __init__(self, model: type[ModelType]): |
|
||||
""" |
|
||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD). |
|
||||
|
|
||||
**Parameters** |
|
||||
|
|
||||
* `model`: A SQLAlchemy model class |
|
||||
* `schema`: A Pydantic model (schema) class |
|
||||
""" |
|
||||
self.model = model |
|
||||
|
|
||||
def get(self, db: Session, id: Any) -> ModelType | None: |
|
||||
return db.query(self.model).filter(self.model.id == id).first() |
|
||||
|
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType: |
|
||||
obj_in_data = jsonable_encoder(obj_in) |
|
||||
db_obj = self.model(**obj_in_data) # type: ignore |
|
||||
db.add(db_obj) |
|
||||
db.commit() |
|
||||
db.refresh(db_obj) |
|
||||
return db_obj |
|
||||
|
|
||||
def update( |
|
||||
self, |
|
||||
db: Session, |
|
||||
*, |
|
||||
db_obj: ModelType, |
|
||||
obj_in: UpdateSchemaType | dict[str, Any], |
|
||||
) -> ModelType: |
|
||||
obj_data = jsonable_encoder(db_obj) |
|
||||
if isinstance(obj_in, dict): |
|
||||
update_data = obj_in |
|
||||
else: |
|
||||
update_data = obj_in.dict(exclude_unset=True) |
|
||||
for field in obj_data: |
|
||||
if field in update_data: |
|
||||
setattr(db_obj, field, update_data[field]) |
|
||||
db.add(db_obj) |
|
||||
db.commit() |
|
||||
db.refresh(db_obj) |
|
||||
return db_obj |
|
||||
|
|
||||
def remove(self, db: Session, *, id: int) -> ModelType: |
|
||||
obj = db.query(self.model).get(id) |
|
||||
db.delete(obj) |
|
||||
db.commit() |
|
||||
return obj |
|
@ -1,32 +0,0 @@ |
|||||
from fastapi.encoders import jsonable_encoder |
|
||||
from sqlalchemy.orm import Session |
|
||||
|
|
||||
from app.crud.base import CRUDBase |
|
||||
from app.models import Item |
|
||||
from app.schemas.item import ItemCreate, ItemUpdate |
|
||||
|
|
||||
|
|
||||
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]): |
|
||||
def create_with_owner( |
|
||||
self, db: Session, *, obj_in: ItemCreate, owner_id: int |
|
||||
) -> Item: |
|
||||
obj_in_data = jsonable_encoder(obj_in) |
|
||||
db_obj = self.model(**obj_in_data, owner_id=owner_id) |
|
||||
db.add(db_obj) |
|
||||
db.commit() |
|
||||
db.refresh(db_obj) |
|
||||
return db_obj |
|
||||
|
|
||||
def get_multi_by_owner( |
|
||||
self, db: Session, *, owner_id: int, skip: int = 0, limit: int = 100 |
|
||||
) -> list[Item]: |
|
||||
return ( |
|
||||
db.query(self.model) |
|
||||
.filter(Item.owner_id == owner_id) |
|
||||
.offset(skip) |
|
||||
.limit(limit) |
|
||||
.all() |
|
||||
) |
|
||||
|
|
||||
|
|
||||
item = CRUDItem(Item) |
|
@ -1,55 +0,0 @@ |
|||||
from typing import Any |
|
||||
|
|
||||
from sqlalchemy.orm import Session |
|
||||
|
|
||||
from app.core.security import get_password_hash, verify_password |
|
||||
from app.crud.base import CRUDBase |
|
||||
from app.models import User |
|
||||
from app.schemas.user import UserCreate, UserUpdate |
|
||||
|
|
||||
|
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): |
|
||||
def get_by_email(self, db: Session, *, email: str) -> User | None: |
|
||||
return db.query(User).filter(User.email == email).first() |
|
||||
|
|
||||
def create(self, db: Session, *, obj_in: UserCreate) -> User: |
|
||||
db_obj = User( |
|
||||
email=obj_in.email, |
|
||||
hashed_password=get_password_hash(obj_in.password), |
|
||||
full_name=obj_in.full_name, |
|
||||
is_superuser=obj_in.is_superuser, |
|
||||
) |
|
||||
db.add(db_obj) |
|
||||
db.commit() |
|
||||
db.refresh(db_obj) |
|
||||
return db_obj |
|
||||
|
|
||||
def update( |
|
||||
self, db: Session, *, db_obj: User, obj_in: UserUpdate | dict[str, Any] |
|
||||
) -> User: |
|
||||
if isinstance(obj_in, dict): |
|
||||
update_data = obj_in |
|
||||
else: |
|
||||
update_data = obj_in.dict(exclude_unset=True) |
|
||||
if update_data["password"]: |
|
||||
hashed_password = get_password_hash(update_data["password"]) |
|
||||
del update_data["password"] |
|
||||
update_data["hashed_password"] = hashed_password |
|
||||
return super().update(db, db_obj=db_obj, obj_in=update_data) |
|
||||
|
|
||||
def authenticate(self, db: Session, *, email: str, password: str) -> User | None: |
|
||||
user = self.get_by_email(db, email=email) |
|
||||
if not user: |
|
||||
return None |
|
||||
if not verify_password(password, user.hashed_password): |
|
||||
return None |
|
||||
return user |
|
||||
|
|
||||
def is_active(self, user: User) -> bool: |
|
||||
return user.is_active |
|
||||
|
|
||||
def is_superuser(self, user: User) -> bool: |
|
||||
return user.is_superuser |
|
||||
|
|
||||
|
|
||||
user = CRUDUser(User) |
|
@ -1,61 +0,0 @@ |
|||||
from sqlalchemy.orm import Session |
|
||||
|
|
||||
from app import crud |
|
||||
from app.schemas.item import ItemCreate, ItemUpdate |
|
||||
from app.tests.utils.user import create_random_user |
|
||||
from app.tests.utils.utils import random_lower_string |
|
||||
|
|
||||
|
|
||||
def test_create_item(db: Session) -> None: |
|
||||
title = random_lower_string() |
|
||||
description = random_lower_string() |
|
||||
item_in = ItemCreate(title=title, description=description) |
|
||||
user = create_random_user(db) |
|
||||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|
||||
assert item.title == title |
|
||||
assert item.description == description |
|
||||
assert item.owner_id == user.id |
|
||||
|
|
||||
|
|
||||
def test_get_item(db: Session) -> None: |
|
||||
title = random_lower_string() |
|
||||
description = random_lower_string() |
|
||||
item_in = ItemCreate(title=title, description=description) |
|
||||
user = create_random_user(db) |
|
||||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|
||||
stored_item = crud.item.get(db=db, id=item.id) |
|
||||
assert stored_item |
|
||||
assert item.id == stored_item.id |
|
||||
assert item.title == stored_item.title |
|
||||
assert item.description == stored_item.description |
|
||||
assert item.owner_id == stored_item.owner_id |
|
||||
|
|
||||
|
|
||||
def test_update_item(db: Session) -> None: |
|
||||
title = random_lower_string() |
|
||||
description = random_lower_string() |
|
||||
item_in = ItemCreate(title=title, description=description) |
|
||||
user = create_random_user(db) |
|
||||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|
||||
description2 = random_lower_string() |
|
||||
item_update = ItemUpdate(description=description2) |
|
||||
item2 = crud.item.update(db=db, db_obj=item, obj_in=item_update) |
|
||||
assert item.id == item2.id |
|
||||
assert item.title == item2.title |
|
||||
assert item2.description == description2 |
|
||||
assert item.owner_id == item2.owner_id |
|
||||
|
|
||||
|
|
||||
def test_delete_item(db: Session) -> None: |
|
||||
title = random_lower_string() |
|
||||
description = random_lower_string() |
|
||||
item_in = ItemCreate(title=title, description=description) |
|
||||
user = create_random_user(db) |
|
||||
item = crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=user.id) |
|
||||
item2 = crud.item.remove(db=db, id=item.id) |
|
||||
item3 = crud.item.get(db=db, id=item.id) |
|
||||
assert item3 is None |
|
||||
assert item2.id == item.id |
|
||||
assert item2.title == title |
|
||||
assert item2.description == description |
|
||||
assert item2.owner_id == user.id |
|
@ -1,16 +1,16 @@ |
|||||
from sqlalchemy.orm import Session |
from sqlmodel import Session |
||||
|
|
||||
from app import crud, models |
from app import crud |
||||
from app.schemas.item import ItemCreate |
from app.models import Item, ItemCreate |
||||
from app.tests.utils.user import create_random_user |
from app.tests.utils.user import create_random_user |
||||
from app.tests.utils.utils import random_lower_string |
from app.tests.utils.utils import random_lower_string |
||||
|
|
||||
|
|
||||
def create_random_item(db: Session, *, owner_id: int | None = None) -> models.Item: |
def create_random_item(db: Session) -> Item: |
||||
if owner_id is None: |
user = create_random_user(db) |
||||
user = create_random_user(db) |
owner_id = user.id |
||||
owner_id = user.id |
assert owner_id is not None |
||||
title = random_lower_string() |
title = random_lower_string() |
||||
description = random_lower_string() |
description = random_lower_string() |
||||
item_in = ItemCreate(title=title, description=description, id=id) |
item_in = ItemCreate(title=title, description=description) |
||||
return crud.item.create_with_owner(db=db, obj_in=item_in, owner_id=owner_id) |
return crud.create_item(session=db, item_in=item_in, owner_id=owner_id) |
||||
|
Loading…
Reference in new issue